This repository was archived by the owner on Nov 22, 2022. It is now read-only.

Commit e17514d

snisarg authored and facebook-github-bot committed
Scripted tokenizer support for DocModel (#1314)
Summary:
Pull Request resolved: #1314

Adding scripted tokenization support to the most widely used model.

Differential Revision: D20955370

fbshipit-source-id: 91c60db42c8eb8deac8afb573942053ac9555e99
Parent: f2b55ff

1 file changed: +24 −0 lines

pytext/models/doc_model.py

@@ -14,6 +14,7 @@
     TokenTensorizer,
     UidTensorizer,
 )
+from pytext.data.tokenizers import DoNothingTokenizer
 from pytext.data.utils import PAD, UNK
 from pytext.exporters.exporter import ModelExporter
 from pytext.loss import BinaryCrossEntropyLoss, MultiLabelSoftMarginLoss

@@ -115,6 +116,13 @@ def torchscriptify(self, tensorizers, traced_model):

         input_vocab = tensorizers["tokens"].vocab
         max_seq_len = tensorizers["tokens"].max_seq_len or -1
+        scripted_tokenizer = None
+        try:
+            scripted_tokenizer = tensorizers["tokens"].tokenizer.torchscriptify()
+        except NotImplementedError:
+            pass
+        if scripted_tokenizer and isinstance(scripted_tokenizer, DoNothingTokenizer):
+            scripted_tokenizer = None

         """
         The input tensor packing memory is allocated/cached for different shapes,

@@ -135,6 +143,7 @@ def __init__(self):
                 self.output_layer = output_layer
                 self.pad_idx = jit.Attribute(input_vocab.get_pad_index(), int)
                 self.max_seq_len = jit.Attribute(max_seq_len, int)
+                self.tokenizer = scripted_tokenizer

             @jit.script_method
             def forward(

@@ -144,6 +153,13 @@ def forward(
                 tokens: Optional[List[List[str]]] = None,
                 languages: Optional[List[str]] = None,
             ):
+                if texts is not None and tokens is not None:
+                    raise RuntimeError("Can't set both tokens and texts")
+                if self.tokenizer is not None and texts is not None:
+                    tokens = [
+                        [t[0] for t in self.tokenizer.tokenize(text)] for text in texts
+                    ]
+
                 if tokens is None:
                     raise RuntimeError("tokens is required")

@@ -167,6 +183,7 @@ def __init__(self):
                 self.output_layer = output_layer
                 self.pad_idx = jit.Attribute(input_vocab.get_pad_index(), int)
                 self.max_seq_len = jit.Attribute(max_seq_len, int)
+                self.tokenizer = scripted_tokenizer

             @jit.script_method
             def forward(

@@ -177,6 +194,13 @@ def forward(
                 languages: Optional[List[str]] = None,
                 dense_feat: Optional[List[List[float]]] = None,
             ):
+                if texts is not None and tokens is not None:
+                    raise RuntimeError("Can't set both tokens and texts")
+                if self.tokenizer is not None and texts is not None:
+                    tokens = [
+                        [t[0] for t in self.tokenizer.tokenize(text)] for text in texts
+                    ]
+
                 if tokens is None:
                     raise RuntimeError("tokens is required")
                 if dense_feat is None:
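
The forward() methods above keep only t[0] from each tokenizer result, i.e. the token string: a scripted tokenizer's tokenize() here yields (token, start, end) triples. For context, below is a minimal sketch of a tokenizer following that convention. The WhitespaceScriptTokenizer class is purely illustrative, not PyText's own implementation; a real one comes from tokenizer.torchscriptify(), and the diff discards DoNothingTokenizer presumably because it expects already-tokenized input, so bundling it would add nothing.

    import torch
    from torch import jit
    from typing import List, Tuple

    class WhitespaceScriptTokenizer(jit.ScriptModule):
        # Hypothetical scripted tokenizer: splits on spaces and returns
        # (token, start, end) triples, the convention the forward()
        # methods above assume when they keep only t[0].

        @jit.script_method
        def tokenize(self, text: str) -> List[Tuple[str, int, int]]:
            tokens: List[Tuple[str, int, int]] = []
            cursor = 0
            for piece in text.split(" "):
                if len(piece) == 0:
                    continue
                begin = text.find(piece, cursor)
                end = begin + len(piece)
                tokens.append((piece, begin, end))
                cursor = end
            return tokens

    # Returns [("hello", 0, 5), ("world", 6, 11)]
    print(WhitespaceScriptTokenizer().tokenize("hello world"))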

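End to end, the change means a torchscriptified DocModel can consume raw strings when its configured tokenizer is scriptable, while pre-tokenized input keeps working. A minimal usage sketch, assuming the exported module was saved to the hypothetical path doc_model.pt:

    import torch

    # Load the TorchScript module produced by torchscriptify()
    # (the file path is an assumption for illustration).
    model = torch.jit.load("doc_model.pt")

    # With a scripted tokenizer bundled, raw strings are tokenized inside
    # the module; passing both texts and tokens raises a RuntimeError.
    scores = model(texts=["hello world", "good morning"])

    # Pre-tokenized input still works whether or not a tokenizer was bundled.
    scores = model(tokens=[["hello", "world"], ["good", "morning"]])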