14
14
TokenTensorizer ,
15
15
UidTensorizer ,
16
16
)
17
+ from pytext .data .tokenizers import DoNothingTokenizer
17
18
from pytext .data .utils import PAD , UNK
18
19
from pytext .exporters .exporter import ModelExporter
19
20
from pytext .loss import BinaryCrossEntropyLoss , MultiLabelSoftMarginLoss
@@ -115,6 +116,13 @@ def torchscriptify(self, tensorizers, traced_model):
115
116
116
117
input_vocab = tensorizers ["tokens" ].vocab
117
118
max_seq_len = tensorizers ["tokens" ].max_seq_len or - 1
119
+ scripted_tokenizer = None
120
+ try :
121
+ scripted_tokenizer = tensorizers ["tokens" ].tokenizer .torchscriptify ()
122
+ except NotImplementedError :
123
+ pass
124
+ if scripted_tokenizer and isinstance (scripted_tokenizer , DoNothingTokenizer ):
125
+ scripted_tokenizer = None
118
126
119
127
"""
120
128
The input tensor packing memory is allocated/cached for different shapes,
@@ -135,6 +143,7 @@ def __init__(self):
135
143
self .output_layer = output_layer
136
144
self .pad_idx = jit .Attribute (input_vocab .get_pad_index (), int )
137
145
self .max_seq_len = jit .Attribute (max_seq_len , int )
146
+ self .tokenizer = scripted_tokenizer
138
147
139
148
@jit .script_method
140
149
def forward (
@@ -144,6 +153,13 @@ def forward(
144
153
tokens : Optional [List [List [str ]]] = None ,
145
154
languages : Optional [List [str ]] = None ,
146
155
):
156
+ if texts is not None and tokens is not None :
157
+ raise RuntimeError ("Can't set both tokens and texts" )
158
+ if self .tokenizer is not None and texts is not None :
159
+ tokens = [
160
+ [t [0 ] for t in self .tokenizer .tokenize (text )] for text in texts
161
+ ]
162
+
147
163
if tokens is None :
148
164
raise RuntimeError ("tokens is required" )
149
165
@@ -167,6 +183,7 @@ def __init__(self):
167
183
self .output_layer = output_layer
168
184
self .pad_idx = jit .Attribute (input_vocab .get_pad_index (), int )
169
185
self .max_seq_len = jit .Attribute (max_seq_len , int )
186
+ self .tokenizer = scripted_tokenizer
170
187
171
188
@jit .script_method
172
189
def forward (
@@ -177,6 +194,13 @@ def forward(
177
194
languages : Optional [List [str ]] = None ,
178
195
dense_feat : Optional [List [List [float ]]] = None ,
179
196
):
197
+ if texts is not None and tokens is not None :
198
+ raise RuntimeError ("Can't set both tokens and texts" )
199
+ if self .tokenizer is not None and texts is not None :
200
+ tokens = [
201
+ [t [0 ] for t in self .tokenizer .tokenize (text )] for text in texts
202
+ ]
203
+
180
204
if tokens is None :
181
205
raise RuntimeError ("tokens is required" )
182
206
if dense_feat is None :
0 commit comments