This repository was archived by the owner on Nov 22, 2022. It is now read-only.

Commit 014fb4c

chenyangyu1988 authored and facebook-github-bot committed

add max_seq_len to DocNN TorchScript model (#1279)
Summary: Pull Request resolved: #1279

This could dramatically reduce the memory usage of the DocNN TorchScript model.

Quick experiment:
https://our.intern.facebook.com/intern/anp/view/?id=215724
https://fb.workplace.com/groups/1941258842562334/permalink/3002460646442143/

Reviewed By: m3rlin45

Differential Revision: D20409424

fbshipit-source-id: 1794e2c687c2463b98f1c62cd6842cb1a1b8cda6
1 parent 495e455 commit 014fb4c

File tree

1 file changed: +28 -4 lines changed


pytext/models/doc_model.py

Lines changed: 28 additions & 4 deletions
```diff
@@ -109,6 +109,14 @@ def torchscriptify(self, tensorizers, traced_model):
         output_layer = self.output_layer.torchscript_predictions()
 
         input_vocab = tensorizers["tokens"].vocab
+        max_seq_len = tensorizers["tokens"].max_seq_len or -1
+
+        """
+        The input tensor packing memory is allocated/cached for different shapes,
+        and max sequence length will help to reduce the number of different tensor
+        shapes. We noticed that the TorchScript model could use 25G for offline
+        inference on CPU without using max_seq_len.
+        """
 
         class Model(jit.ScriptModule):
             def __init__(self):
```
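The docstring's argument is about shape-keyed allocation caches: packing buffers are cached per distinct input shape, so bounding the sequence dimension bounds how many shapes (and cached buffers) can ever exist. A minimal standalone sketch of that effect, not from the commit, with `padded_shape` and the random batch sizes purely illustrative:

```python
import random

def padded_shape(batch, max_seq_len=-1):
    # Shape of the padded id tensor for one batch: (batch_size, longest_seq),
    # with the longest sequence optionally capped at max_seq_len.
    lengths = [len(seq) if max_seq_len < 0 else min(len(seq), max_seq_len)
               for seq in batch]
    return (len(batch), max(lengths))

random.seed(0)
batches = [[["tok"] * random.randint(1, 2000) for _ in range(16)]
           for _ in range(200)]

uncapped = {padded_shape(b) for b in batches}
capped = {padded_shape(b, max_seq_len=256) for b in batches}
print(len(uncapped), len(capped))  # many distinct shapes vs. one or two
```

With the cap, nearly every batch collapses to the single shape (16, 256), so a shape-keyed cache holds one buffer instead of one per observed maximum length.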
```diff
@@ -117,6 +125,7 @@ def __init__(self):
                 self.model = traced_model
                 self.output_layer = output_layer
                 self.pad_idx = jit.Attribute(input_vocab.idx[PAD], int)
+                self.max_seq_len = jit.Attribute(max_seq_len, int)
 
             @jit.script_method
             def forward(
@@ -128,8 +137,15 @@ def forward(
                 if tokens is None:
                     raise RuntimeError("tokens is required")
 
-                seq_lens = make_sequence_lengths(tokens)
-                word_ids = self.vocab.lookup_indices_2d(tokens)
+                trimmed_tokens: List[List[str]] = []
+                if self.max_seq_len >= 0:
+                    for token in tokens:
+                        trimmed_tokens.append(token[0 : self.max_seq_len])
+                else:
+                    trimmed_tokens = tokens
+
+                seq_lens = make_sequence_lengths(trimmed_tokens)
+                word_ids = self.vocab.lookup_indices_2d(trimmed_tokens)
                 word_ids = pad_2d(word_ids, seq_lens, self.pad_idx)
                 logits = self.model(torch.tensor(word_ids), torch.tensor(seq_lens))
                 return self.output_layer(logits)
```
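The new forward() is written in the explicit loop-and-append style that TorchScript's static typing expects (hence the `trimmed_tokens: List[List[str]] = []` annotation). The same trimming, extracted into a plain function for illustration (a sketch, not part of the commit):

```python
from typing import List

def trim_tokens(tokens: List[List[str]], max_seq_len: int) -> List[List[str]]:
    # Mirror of the scripted forward(): a negative max_seq_len disables
    # trimming; otherwise every document is cut to at most max_seq_len tokens.
    if max_seq_len < 0:
        return tokens
    trimmed: List[List[str]] = []
    for token in tokens:
        trimmed.append(token[0:max_seq_len])
    return trimmed

assert trim_tokens([["a", "b", "c"], ["d"]], 2) == [["a", "b"], ["d"]]
assert trim_tokens([["a", "b", "c"]], -1) == [["a", "b", "c"]]
```

Because seq_lens and word_ids are computed from the trimmed lists, both the length tensor and the padded id tensor downstream are bounded by max_seq_len.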
```diff
@@ -142,6 +158,7 @@ def __init__(self):
                 self.model = traced_model
                 self.output_layer = output_layer
                 self.pad_idx = jit.Attribute(input_vocab.idx[PAD], int)
+                self.max_seq_len = jit.Attribute(max_seq_len, int)
 
             @jit.script_method
             def forward(
@@ -156,8 +173,15 @@ def forward(
                 if dense_feat is None:
                     raise RuntimeError("dense_feat is required")
 
-                seq_lens = make_sequence_lengths(tokens)
-                word_ids = self.vocab.lookup_indices_2d(tokens)
+                trimmed_tokens: List[List[str]] = []
+                if self.max_seq_len >= 0:
+                    for token in tokens:
+                        trimmed_tokens.append(token[0 : self.max_seq_len])
+                else:
+                    trimmed_tokens = tokens
+
+                seq_lens = make_sequence_lengths(trimmed_tokens)
+                word_ids = self.vocab.lookup_indices_2d(trimmed_tokens)
                 word_ids = pad_2d(word_ids, seq_lens, self.pad_idx)
                 dense_feat = self.normalizer.normalize(dense_feat)
                 logits = self.model(
```
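The cap is read from the tensorizer (`tensorizers["tokens"].max_seq_len or -1`), so an unset max_seq_len (None) becomes the -1 sentinel and trimming is disabled. Once scripted, the cap is baked in as a jit.Attribute and callers pass raw token lists unchanged; a hedged usage sketch (the file name is illustrative, and the call is abbreviated to the `tokens` argument shown in the diff):

```python
import torch

# Illustrative path; the commit does not show how the model is serialized.
scripted = torch.jit.load("docnn.pt")

# One very long document: forward() trims it to max_seq_len before building
# tensors, so the padded width stays bounded regardless of input length.
long_doc = [["word"] * 10_000]
scores = scripted(tokens=long_doc)
```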
