
Commit 9d9f784

chenyangyu1988 authored and facebook-github-bot committed
use ScriptXLMTensorizer (#1123)
Summary:
Pull Request resolved: #1123

use ScriptXLMTensorizer

Differential Revision: D18364254

fbshipit-source-id: 20966a39aa3631cd84cfb9a778bd19c7f8d03cc8
1 parent 8e44fb5 commit 9d9f784

File tree

1 file changed: 22 additions, 0 deletions

pytext/data/xlm_tensorizer.py

Lines changed: 22 additions & 0 deletions

@@ -11,6 +11,8 @@
 from pytext.data.tokenizers import Tokenizer
 from pytext.data.utils import EOS, MASK, PAD, UNK, Vocabulary
 from pytext.data.xlm_constants import LANG2ID_15
+from pytext.torchscript.tensorizer import ScriptXLMTensorizer
+from pytext.torchscript.vocab import ScriptVocabulary


 class XLMTensorizer(BERTTensorizerBase):
@@ -137,3 +139,23 @@ def numberize(self, row: Dict) -> Tuple[Any, ...]:
         seq_len = len(tokens)
         positions = [index for index in range(seq_len)]
         return tokens, segment_labels, seq_len, positions
+
+    def torchscriptify(self, languages=None, default_language="en"):
+        if languages is None:
+            languages = [0] * (max(list(self.lang2id.values())) + 1)
+            for k, v in self.lang2id.items():
+                languages[v] = k
+
+        return ScriptXLMTensorizer(
+            tokenizer=self.tokenizer.torchscriptify(),
+            token_vocab=ScriptVocabulary(
+                list(self.vocab),
+                pad_idx=self.vocab.get_pad_index(),
+                bos_idx=self.vocab.get_eos_index(),
+                eos_idx=self.vocab.get_eos_index(),
+                unk_idx=self.vocab.get_unk_index(),
+            ),
+            language_vocab=ScriptVocabulary(languages),
+            max_seq_len=self.max_seq_len,
+            default_language=default_language,
+        )
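
For context on the default built in torchscriptify above: when no languages list is passed, the method inverts the lang2id mapping into a positional list indexed by language id, which is then wrapped in ScriptVocabulary(languages). Below is a minimal standalone sketch of that inversion, using a made-up three-language mapping rather than the real LANG2ID_15 constant.

# Illustrative sketch only: a toy lang2id mapping, not the real LANG2ID_15.
lang2id = {"en": 0, "fr": 1, "de": 2}

# Build a list where index v holds the language code whose id is v,
# mirroring the default-languages construction in torchscriptify().
languages = [0] * (max(list(lang2id.values())) + 1)
for k, v in lang2id.items():
    languages[v] = k

print(languages)  # prints ['en', 'fr', 'de']

Any id missing from the mapping keeps its 0 placeholder; the resulting list is handed to ScriptVocabulary(languages), presumably so the scripted tensorizer can look up language ids from language codes at inference time.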
