@@ -11,6 +11,8 @@
 from pytext.data.tokenizers import Tokenizer
 from pytext.data.utils import EOS, MASK, PAD, UNK, Vocabulary
 from pytext.data.xlm_constants import LANG2ID_15
+from pytext.torchscript.tensorizer import ScriptXLMTensorizer
+from pytext.torchscript.vocab import ScriptVocabulary
 
 
 class XLMTensorizer(BERTTensorizerBase):
@@ -137,3 +139,23 @@ def numberize(self, row: Dict) -> Tuple[Any, ...]:
         seq_len = len(tokens)
         positions = [index for index in range(seq_len)]
         return tokens, segment_labels, seq_len, positions
+
+    def torchscriptify(self, languages=None, default_language="en"):
+        if languages is None:
+            languages = [0] * (max(list(self.lang2id.values())) + 1)
+            for k, v in self.lang2id.items():
+                languages[v] = k
+
+        return ScriptXLMTensorizer(
+            tokenizer=self.tokenizer.torchscriptify(),
+            token_vocab=ScriptVocabulary(
+                list(self.vocab),
+                pad_idx=self.vocab.get_pad_index(),
+                bos_idx=self.vocab.get_eos_index(),
+                eos_idx=self.vocab.get_eos_index(),
+                unk_idx=self.vocab.get_unk_index(),
+            ),
+            language_vocab=ScriptVocabulary(languages),
+            max_seq_len=self.max_seq_len,
+            default_language=default_language,
+        )
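A note on the added method: `lang2id` (e.g. `LANG2ID_15`) maps a language code to its embedding index, and `torchscriptify` inverts it into a dense list so that `ScriptVocabulary` assigns each language code its original index. Setting `bos_idx` to the EOS index also appears intentional rather than a typo: XLM-style inputs mark both sentence boundaries with the same `</s>` token. Below is a minimal, self-contained sketch of the inversion step only; the toy `lang2id` mapping is illustrative, not the real `LANG2ID_15`.

# Toy stand-in for a lang2id mapping such as LANG2ID_15 ({code: index}).
lang2id = {"en": 0, "fr": 1, "de": 2}

# Mirror the inversion in torchscriptify(): allocate a dense list indexed
# by language id (placeholder zeros), then drop each code into its slot.
languages = [0] * (max(list(lang2id.values())) + 1)
for code, idx in lang2id.items():
    languages[idx] = code

assert languages == ["en", "fr", "de"]

Pre-sizing the list from the maximum id (rather than appending in dict order) keeps each language at the exact index the model was trained with, even if `lang2id` iteration order differs.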