
Commit ca10b8e

chenyangyu1988 authored and facebook-github-bot committed
use ScriptXLMTensorizer (#1123)
use ScriptXLMTensorizer

Summary: Pull Request resolved: #1123
Reviewed By: rutyrinott
Differential Revision: D18364254
fbshipit-source-id: 1d607288f9e10c909f7a42fc2a5d53d94ca2e815
1 parent b3f75cf commit ca10b8e

File tree

2 files changed: +71 −2 lines changed

pytext/data/xlm_tensorizer.py

23 additions, 1 deletion

@@ -11,6 +11,8 @@
 from pytext.data.tokenizers import Tokenizer
 from pytext.data.utils import EOS, MASK, PAD, UNK, Vocabulary
 from pytext.data.xlm_constants import LANG2ID_15
+from pytext.torchscript.tensorizer import ScriptXLMTensorizer
+from pytext.torchscript.vocab import ScriptVocabulary


 class XLMTensorizer(BERTTensorizerBase):

@@ -85,6 +87,7 @@ def __init__(
         # unlike BERT, XLM uses the EOS token for both beginning and end of
         # sentence
         self.bos_token = self.vocab.eos_token
+        self.default_language = "en"

     @property
     def column_schema(self):

@@ -103,7 +106,7 @@ def get_lang_id(self, row: Dict, col: str) -> int:
             return lang_id
         else:
             # use En as default
-            return self.lang2id.get("en", 0)
+            return self.lang2id.get(self.default_language, 0)

     def _lookup_tokens(self, text: str, seq_len: int) -> List[str]:
         return lookup_tokens(

@@ -137,3 +140,22 @@ def numberize(self, row: Dict) -> Tuple[Any, ...]:
         seq_len = len(tokens)
         positions = [index for index in range(seq_len)]
         return tokens, segment_labels, seq_len, positions
+
+    def torchscriptify(self):
+        languages = [0] * (max(list(self.lang2id.values())) + 1)
+        for k, v in self.lang2id.items():
+            languages[v] = k
+
+        return ScriptXLMTensorizer(
+            tokenizer=self.tokenizer.torchscriptify(),
+            token_vocab=ScriptVocabulary(
+                list(self.vocab),
+                pad_idx=self.vocab.get_pad_index(),
+                bos_idx=self.vocab.get_eos_index(),
+                eos_idx=self.vocab.get_eos_index(),
+                unk_idx=self.vocab.get_unk_index(),
+            ),
+            language_vocab=ScriptVocabulary(languages),
+            max_seq_len=self.max_seq_len,
+            default_language=self.default_language,
+        )
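
The new torchscriptify() inverts lang2id (language code → integer id) into a dense list indexed by id, which ScriptVocabulary then consumes as the language vocabulary. A minimal standalone sketch of that inversion, using a hypothetical toy mapping in place of the real LANG2ID_15 table:

# Sketch of the id -> language inversion in torchscriptify() above.
# The mapping below is a toy example; PyText uses LANG2ID_15.
lang2id = {"en": 0, "fr": 1, "de": 2}

# Build a dense list where position v holds the language code with id v,
# exactly as in the commit.
languages = [0] * (max(list(lang2id.values())) + 1)
for lang, lang_id in lang2id.items():
    languages[lang_id] = lang

print(languages)  # ['en', 'fr', 'de']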

pytext/torchscript/module.py

48 additions, 1 deletion

@@ -1,7 +1,7 @@
 #!/usr/bin/env python3
 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

-from typing import List
+from typing import List, Optional

 import torch
 from pytext.torchscript.tensorizer.tensorizer import ScriptTensorizer

@@ -53,3 +53,50 @@ def forward(self, tokens: List[List[str]]):
         input_tensors = self.tensorizer.tensorize(tokens=squeeze_2d(tokens))
         logits = self.model(input_tensors)
         return self.output_layer(logits)
+
+
+class ScriptTokenLanguageModule(torch.jit.ScriptModule):
+    def __init__(
+        self,
+        model: torch.jit.ScriptModule,
+        output_layer: torch.jit.ScriptModule,
+        tensorizer: ScriptTensorizer,
+    ):
+        super().__init__()
+        self.model = model
+        self.output_layer = output_layer
+        self.tensorizer = tensorizer
+
+    @torch.jit.script_method
+    def forward(self, tokens: List[List[str]], languages: Optional[List[str]] = None):
+        input_tensors = self.tensorizer.tensorize(
+            tokens=squeeze_2d(tokens), languages=squeeze_1d(languages)
+        )
+        logits = self.model(input_tensors)
+        return self.output_layer(logits)
+
+
+class ScriptTokenLanguageModuleWithDenseFeature(torch.jit.ScriptModule):
+    def __init__(
+        self,
+        model: torch.jit.ScriptModule,
+        output_layer: torch.jit.ScriptModule,
+        tensorizer: ScriptTensorizer,
+    ):
+        super().__init__()
+        self.model = model
+        self.output_layer = output_layer
+        self.tensorizer = tensorizer
+
+    @torch.jit.script_method
+    def forward(
+        self,
+        tokens: List[List[str]],
+        dense_feat: List[List[float]],
+        languages: Optional[List[str]] = None,
+    ):
+        input_tensors = self.tensorizer.tensorize(
+            tokens=squeeze_2d(tokens), languages=squeeze_1d(languages)
+        )
+        logits = self.model(input_tensors, torch.tensor(dense_feat).float())
+        return self.output_layer(logits)
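
For context, a hedged usage sketch of the new ScriptTokenLanguageModule. It assumes model, output_layer, and tensorizer are already-scripted objects matching the constructor signature above (e.g. the tensorizer produced by XLMTensorizer.torchscriptify() from the first file); the token and language values are illustrative only:

# Usage sketch, not from this commit. `model`, `output_layer`, and
# `tensorizer` are assumed to exist and match the types in the
# constructor signature above.
module = ScriptTokenLanguageModule(
    model=model,
    output_layer=output_layer,
    tensorizer=tensorizer,
)

# `languages` is Optional[List[str]]; when omitted, the tensorizer
# presumably falls back to its default_language ("en" in the first diff).
predictions = module(
    [["hello", "world"]],  # one pre-tokenized sentence per example
    ["en"],                # one language code per example
)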
