@@ -109,6 +109,14 @@ def torchscriptify(self, tensorizers, traced_model):
         output_layer = self.output_layer.torchscript_predictions()
 
         input_vocab = tensorizers["tokens"].vocab
+        max_seq_len = tensorizers["tokens"].max_seq_len or -1
+
+        """
+        The input tensor packing memory is allocated/cached for different shapes,
+        and max sequence length will help to reduce the number of different tensor
+        shapes. We noticed that the TorchScript model could use 25G for offline
+        inference on CPU without using max_seq_len.
+        """
 
         class Model(jit.ScriptModule):
             def __init__(self):
@@ -117,6 +125,7 @@ def __init__(self):
                 self.model = traced_model
                 self.output_layer = output_layer
                 self.pad_idx = jit.Attribute(input_vocab.idx[PAD], int)
+                self.max_seq_len = jit.Attribute(max_seq_len, int)
 
             @jit.script_method
             def forward(
@@ -128,8 +137,15 @@ def forward(
                 if tokens is None:
                     raise RuntimeError("tokens is required")
 
-                seq_lens = make_sequence_lengths(tokens)
-                word_ids = self.vocab.lookup_indices_2d(tokens)
+                trimmed_tokens: List[List[str]] = []
+                if self.max_seq_len >= 0:
+                    for token in tokens:
+                        trimmed_tokens.append(token[0 : self.max_seq_len])
+                else:
+                    trimmed_tokens = tokens
+
+                seq_lens = make_sequence_lengths(trimmed_tokens)
+                word_ids = self.vocab.lookup_indices_2d(trimmed_tokens)
                 word_ids = pad_2d(word_ids, seq_lens, self.pad_idx)
                 logits = self.model(torch.tensor(word_ids), torch.tensor(seq_lens))
                 return self.output_layer(logits)
@@ -142,6 +158,7 @@ def __init__(self):
                 self.model = traced_model
                 self.output_layer = output_layer
                 self.pad_idx = jit.Attribute(input_vocab.idx[PAD], int)
+                self.max_seq_len = jit.Attribute(max_seq_len, int)
 
             @jit.script_method
             def forward(
@@ -156,8 +173,15 @@ def forward(
                 if dense_feat is None:
                     raise RuntimeError("dense_feat is required")
 
-                seq_lens = make_sequence_lengths(tokens)
-                word_ids = self.vocab.lookup_indices_2d(tokens)
+                trimmed_tokens: List[List[str]] = []
+                if self.max_seq_len >= 0:
+                    for token in tokens:
+                        trimmed_tokens.append(token[0 : self.max_seq_len])
+                else:
+                    trimmed_tokens = tokens
+
+                seq_lens = make_sequence_lengths(trimmed_tokens)
+                word_ids = self.vocab.lookup_indices_2d(trimmed_tokens)
                 word_ids = pad_2d(word_ids, seq_lens, self.pad_idx)
                 dense_feat = self.normalizer.normalize(dense_feat)
                 logits = self.model(
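
As context for the trimming logic above: a minimal, standalone sketch (not part of this commit) of why capping sequences at max_seq_len bounds the padded tensor shape. The trim_and_pad helper and the toy index-1 vocabulary lookup are illustrative stand-ins, not PyText APIs.

from typing import List

import torch


def trim_and_pad(tokens: List[List[str]], max_seq_len: int, pad_idx: int = 0) -> torch.Tensor:
    # Mirror the trimmed_tokens loop from the diff: cap every sequence at max_seq_len
    # (a negative max_seq_len means "no trimming", matching the `or -1` default).
    trimmed = [t[0:max_seq_len] for t in tokens] if max_seq_len >= 0 else tokens
    seq_lens = [len(t) for t in trimmed]
    width = max(seq_lens)
    # Stand-in for vocab.lookup_indices_2d + pad_2d: map every token to index 1.
    rows = [[1] * n + [pad_idx] * (width - n) for n in seq_lens]
    return torch.tensor(rows)


# Without trimming, the padded width follows the longest input (here 500); with
# max_seq_len=32 every batch collapses to a (batch, 32) tensor, so TorchScript's
# packing buffers are allocated/cached for far fewer distinct shapes.
print(trim_and_pad([["w"] * 500, ["w"] * 40], max_seq_len=-1).shape)  # torch.Size([2, 500])
print(trim_and_pad([["w"] * 500, ["w"] * 40], max_seq_len=32).shape)  # torch.Size([2, 32])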