This repository was archived by the owner on Nov 22, 2022. It is now read-only.

Commit c1b4155

snisarg authored and facebook-github-bot committed
Scripted tokenizer support for DocModel (#1314)
Summary: Pull Request resolved: #1314

Adding scripted tokenization support to the most widely used model.

OSS test failures: waiting for a TorchScript diff to land: https://fb.workplace.com/groups/329222650990087/permalink/632527153992967/

Differential Revision: D20955370

fbshipit-source-id: 6e002136fc0113dfe87a24ca3f82be9a73d1d0bc
1 parent 522e079 commit c1b4155

File tree

1 file changed: +30 -0 lines changed


pytext/models/doc_model.py

Lines changed: 30 additions & 0 deletions
@@ -14,6 +14,7 @@
     TokenTensorizer,
     UidTensorizer,
 )
+from pytext.data.tokenizers import DoNothingTokenizer
 from pytext.data.utils import PAD, UNK
 from pytext.exporters.exporter import ModelExporter
 from pytext.loss import BinaryCrossEntropyLoss, MultiLabelSoftMarginLoss
@@ -119,6 +120,13 @@ def torchscriptify(self, tensorizers, traced_model):
 
         input_vocab = tensorizers["tokens"].vocab
         max_seq_len = tensorizers["tokens"].max_seq_len or -1
+        scripted_tokenizer = None
+        try:
+            scripted_tokenizer = tensorizers["tokens"].tokenizer.torchscriptify()
+        except NotImplementedError:
+            pass
+        if scripted_tokenizer and isinstance(scripted_tokenizer, DoNothingTokenizer):
+            scripted_tokenizer = None
 
         """
         The input tensor packing memory is allocated/cached for different shapes,
@@ -139,6 +147,7 @@ def __init__(self):
                 self.output_layer = output_layer
                 self.pad_idx = jit.Attribute(input_vocab.get_pad_index(), int)
                 self.max_seq_len = jit.Attribute(max_seq_len, int)
+                self.tokenizer = scripted_tokenizer
 
             @jit.script_method
             def forward(
@@ -148,6 +157,16 @@ def forward(
                 tokens: Optional[List[List[str]]] = None,
                 languages: Optional[List[str]] = None,
             ):
+                # PyTorch breaks with 2 'not None' checks right now.
+                if texts is not None:
+                    if tokens is not None:
+                        raise RuntimeError("Can't set both tokens and texts")
+                    if self.tokenizer is not None:
+                        tokens = [
+                            [t[0] for t in self.tokenizer.tokenize(text)]
+                            for text in texts
+                        ]
+
                 if tokens is None:
                     raise RuntimeError("tokens is required")
 
@@ -171,6 +190,7 @@ def __init__(self):
                 self.output_layer = output_layer
                 self.pad_idx = jit.Attribute(input_vocab.get_pad_index(), int)
                 self.max_seq_len = jit.Attribute(max_seq_len, int)
+                self.tokenizer = scripted_tokenizer
 
             @jit.script_method
             def forward(
@@ -181,6 +201,16 @@ def forward(
                 languages: Optional[List[str]] = None,
                 dense_feat: Optional[List[List[float]]] = None,
             ):
+                # PyTorch breaks with 2 'not None' checks right now.
+                if texts is not None:
+                    if tokens is not None:
+                        raise RuntimeError("Can't set both tokens and texts")
+                    if self.tokenizer is not None:
+                        tokens = [
+                            [t[0] for t in self.tokenizer.tokenize(text)]
+                            for text in texts
+                        ]
+
                 if tokens is None:
                     raise RuntimeError("tokens is required")
                 if dense_feat is None:
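
With a scriptable tokenizer attached, the exported DocModel can accept raw strings through the texts argument instead of pre-tokenized tokens. A minimal usage sketch, not part of this commit: the saved-model path and example inputs are hypothetical, and it assumes the scripted module produced by torchscriptify() was saved and reloaded with torch.jit.

import torch

# Load the exported DocModel (hypothetical path).
scripted_model = torch.jit.load("doc_model.pt")

# Raw strings: tokenization now happens inside TorchScript via the
# scripted tokenizer attached in torchscriptify().
scores_from_text = scripted_model(texts=["set an alarm for 7 am"])

# Pre-tokenized input still works; passing both texts and tokens raises
# a RuntimeError per the check added above.
scores_from_tokens = scripted_model(tokens=[["set", "an", "alarm", "for", "7", "am"]])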
