Skip to content
This repository was archived by the owner on Nov 22, 2022. It is now read-only.

Commit 898c41c

Browse files
hudeven and facebook-github-bot
authored and committed
Unify input for TorchScript Tensorizers and Models (#1256)
Summary: Pull Request resolved: #1256 ## Dataflow from PyText client to TorchScript model in predictor 1. The client sends optional "texts", "tokens", "languages", and "dense_feat" args in the predictor request. 2. The predictor passes them to forward() of the TorchScript Module (Tensorizer + Model). 3. In ScriptPyTextModule(texts, tokens, languages, dense_feat), the args are converted to ScriptBatchInput (NamedTuple) => Tensorizer.forward() => tuple of Tensors => Model.forward() ## Before: We needed a wrapper for each combination of (texts, tokens, languages, dense) ## After: one wrapper with dense features, another wrapper without dense features ## Alternative: ScriptPyTextModule(inputs: ScriptBatchInput), once NamedTuple is supported along the path client example => predictor => ScriptPyTextModule https://fb.workplace.com/groups/811605488888068/permalink/3266598560055403/ Reviewed By: snisarg, chenyangyu1988 Differential Revision: D19900062 fbshipit-source-id: 2f3884d8d93c5b4d67b78033e9fc92b5b7b6e2fe
1 parent 37718f0 commit 898c41c

File tree

10 files changed

+85
-146
lines changed

10 files changed

+85
-146
lines changed

pytext/data/bert_tensorizer.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from pytext.data.tokenizers import Tokenizer, WordPieceTokenizer
1212
from pytext.data.utils import BOS, EOS, MASK, PAD, UNK, SpecialToken, Vocabulary
1313
from pytext.torchscript.tensorizer.tensorizer import VocabLookup
14-
from pytext.torchscript.utils import pad_2d, pad_2d_mask
14+
from pytext.torchscript.utils import ScriptBatchInput, pad_2d, pad_2d_mask
1515
from pytext.torchscript.vocab import ScriptVocabulary
1616
from pytext.utils.file_io import PathManager
1717
from pytext.utils.lazy import lazy_property
@@ -207,9 +207,7 @@ def tokenize(
207207
return per_sentence_tokens
208208

209209
def forward(
210-
self,
211-
texts: Optional[List[List[str]]] = None,
212-
pre_tokenized: Optional[List[List[List[str]]]] = None,
210+
self, inputs: ScriptBatchInput
213211
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
214212
"""
215213
Wire up tokenize(), numberize() and tensorize() functions for data
@@ -223,10 +221,10 @@ def forward(
223221
seq_lens_1d: List[int] = []
224222
positions_2d: List[List[int]] = []
225223

226-
for idx in range(self.batch_size(texts, pre_tokenized)):
224+
for idx in range(self.batch_size(inputs)):
227225
tokens: List[List[Tuple[str, int, int]]] = self.tokenize(
228-
self.get_texts_by_index(texts, idx),
229-
self.get_tokens_by_index(pre_tokenized, idx),
226+
self.get_texts_by_index(inputs.texts, idx),
227+
self.get_tokens_by_index(inputs.tokens, idx),
230228
)
231229

232230
numberized: Tuple[List[int], List[int], int, List[int]] = self.numberize(

pytext/data/tensorizers.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from pytext.data.sources.data_source import Gazetteer
2222
from pytext.data.tokenizers import Token, Tokenizer
2323
from pytext.torchscript.tensorizer import VectorNormalizer
24+
from pytext.torchscript.utils import ScriptBatchInput
2425
from pytext.utils import cuda, precision
2526
from pytext.utils.data import Slot
2627
from pytext.utils.file_io import PathManager
@@ -116,19 +117,19 @@ def __init__(self):
116117
def set_device(self, device: str):
117118
self.device = device
118119

119-
def batch_size(
120-
self, texts: Optional[List[List[str]]], tokens: Optional[List[List[List[str]]]]
121-
) -> int:
120+
def batch_size(self, inputs: ScriptBatchInput) -> int:
121+
texts: Optional[List[List[str]]] = inputs.texts
122+
tokens: Optional[List[List[List[str]]]] = inputs.tokens
122123
if texts is not None:
123124
return len(texts)
124125
elif tokens is not None:
125126
return len(tokens)
126127
else:
127128
raise RuntimeError("Empty input for both texts and tokens.")
128129

129-
def row_size(
130-
self, texts: Optional[List[List[str]]], tokens: Optional[List[List[List[str]]]]
131-
) -> int:
130+
def row_size(self, inputs: ScriptBatchInput) -> int:
131+
texts: Optional[List[List[str]]] = inputs.texts
132+
tokens: Optional[List[List[List[str]]]] = inputs.tokens
132133
if texts is not None:
133134
return len(texts[0])
134135
elif tokens is not None:
@@ -139,14 +140,14 @@ def row_size(
139140
def get_texts_by_index(
140141
self, texts: Optional[List[List[str]]], index: int
141142
) -> Optional[List[str]]:
142-
if texts is None:
143+
if texts is None or len(texts) == 0:
143144
return None
144145
return texts[index]
145146

146147
def get_tokens_by_index(
147148
self, tokens: Optional[List[List[List[str]]]], index: int
148149
) -> Optional[List[List[str]]]:
149-
if tokens is None:
150+
if tokens is None or len(tokens) == 0:
150151
return None
151152
return tokens[index]
152153

pytext/data/xlm_tensorizer.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from pytext.data.tokenizers import Tokenizer
1515
from pytext.data.utils import EOS, MASK, PAD, UNK, Vocabulary
1616
from pytext.data.xlm_constants import LANG2ID_15
17+
from pytext.torchscript.utils import ScriptBatchInput
1718
from pytext.torchscript.vocab import ScriptVocabulary
1819
from pytext.utils.file_io import PathManager
1920
from pytext.utils.lazy import lazy_property
@@ -71,17 +72,15 @@ def numberize(
7172
return tokens, segment_labels, seq_len, positions
7273

7374
def forward(
74-
self,
75-
texts: Optional[List[List[str]]] = None,
76-
pre_tokenized: Optional[List[List[List[str]]]] = None,
77-
languages: Optional[List[List[str]]] = None,
75+
self, inputs: ScriptBatchInput
7876
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
7977
"""
8078
Wire up tokenize(), numberize() and tensorize() functions for data
8179
processing.
8280
"""
83-
batch_size: int = self.batch_size(texts, pre_tokenized)
84-
row_size: int = self.row_size(texts, pre_tokenized)
81+
batch_size: int = self.batch_size(inputs)
82+
row_size: int = self.row_size(inputs)
83+
languages: Optional[List[List[str]]] = inputs.languages
8584
if languages is None:
8685
languages = [[self.default_language] * row_size] * batch_size
8786

@@ -90,10 +89,10 @@ def forward(
9089
seq_lens_1d: List[int] = []
9190
positions_2d: List[List[int]] = []
9291

93-
for idx in range(self.batch_size(texts, pre_tokenized)):
92+
for idx in range(self.batch_size(inputs)):
9493
tokens: List[List[Tuple[str, int, int]]] = self.tokenize(
95-
self.get_texts_by_index(texts, idx),
96-
self.get_tokens_by_index(pre_tokenized, idx),
94+
self.get_texts_by_index(inputs.texts, idx),
95+
self.get_tokens_by_index(inputs.tokens, idx),
9796
)
9897
language_ids: List[int] = [
9998
self.language_vocab.idx.get(

pytext/models/doc_model.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,15 @@ def __init__(self):
119119
self.pad_idx = jit.Attribute(input_vocab.idx[PAD], int)
120120

121121
@jit.script_method
122-
def forward(self, tokens: List[List[str]]):
122+
def forward(
123+
self,
124+
texts: Optional[List[str]] = None,
125+
tokens: Optional[List[List[str]]] = None,
126+
languages: Optional[List[str]] = None,
127+
):
128+
if tokens is None:
129+
raise RuntimeError("tokens is required")
130+
123131
seq_lens = make_sequence_lengths(tokens)
124132
word_ids = self.vocab.lookup_indices_2d(tokens)
125133
word_ids = pad_2d(word_ids, seq_lens, self.pad_idx)
@@ -136,11 +144,23 @@ def __init__(self):
136144
self.pad_idx = jit.Attribute(input_vocab.idx[PAD], int)
137145

138146
@jit.script_method
139-
def forward(self, tokens: List[List[str]], dense_feat: List[List[float]]):
147+
def forward(
148+
self,
149+
texts: Optional[List[str]] = None,
150+
tokens: Optional[List[List[str]]] = None,
151+
languages: Optional[List[str]] = None,
152+
dense_feat: Optional[List[List[float]]] = None,
153+
):
154+
if tokens is None:
155+
raise RuntimeError("tokens is required")
156+
140157
seq_lens = make_sequence_lengths(tokens)
141158
word_ids = self.vocab.lookup_indices_2d(tokens)
142159
word_ids = pad_2d(word_ids, seq_lens, self.pad_idx)
143-
dense_feat = self.normalizer.normalize(dense_feat)
160+
if dense_feat is not None:
161+
dense_feat = self.normalizer.normalize(dense_feat)
162+
else:
163+
raise RuntimeError("dense is required")
144164
logits = self.model(
145165
torch.tensor(word_ids),
146166
torch.tensor(seq_lens),

pytext/models/roberta.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#!/usr/bin/env python3
22
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
33

4-
from typing import Dict, List, Tuple
4+
from typing import Dict, Tuple
55

66
import torch
77
from pytext.common.constants import Stage
@@ -25,7 +25,7 @@
2525
from pytext.models.representations.transformer_sentence_encoder_base import (
2626
TransformerSentenceEncoderBase,
2727
)
28-
from pytext.torchscript.module import get_script_module_cls
28+
from pytext.torchscript.module import ScriptPyTextModule
2929
from pytext.utils.file_io import PathManager
3030
from pytext.utils.usage import log_class_usage
3131
from torch.serialization import default_restore_location
@@ -152,11 +152,7 @@ def torchscriptify(self, tensorizers, traced_model):
152152
values according to the output layer (eg. as a dict mapping class name to score)
153153
"""
154154
script_tensorizer = tensorizers["tokens"].torchscriptify()
155-
script_module_cls = get_script_module_cls(
156-
script_tensorizer.tokenizer.input_type()
157-
)
158-
159-
return script_module_cls(
155+
return ScriptPyTextModule(
160156
model=traced_model,
161157
output_layer=self.output_layer.torchscript_predictions(),
162158
tensorizer=script_tensorizer,

pytext/torchscript/module.py

Lines changed: 19 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,10 @@
11
#!/usr/bin/env python3
22
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
3-
43
from typing import List, Optional
54

65
import torch
76
from pytext.torchscript.tensorizer.tensorizer import ScriptTensorizer
8-
from pytext.torchscript.utils import ScriptInputType, squeeze_1d, squeeze_2d
9-
10-
11-
def get_script_module_cls(input_type: ScriptInputType) -> torch.jit.ScriptModule:
12-
if input_type.is_text():
13-
return ScriptTextModule
14-
elif input_type.is_token():
15-
return ScriptTokenModule
16-
else:
17-
raise RuntimeError("Only support text or token input type...")
7+
from pytext.torchscript.utils import ScriptBatchInput, squeeze_1d, squeeze_2d
188

199

2010
class ScriptModule(torch.jit.ScriptModule):
@@ -23,7 +13,7 @@ def set_device(self, device: str):
2313
self.tensorizer.set_device(device)
2414

2515

26-
class ScriptTextModule(ScriptModule):
16+
class ScriptPyTextModule(ScriptModule):
2717
def __init__(
2818
self,
2919
model: torch.jit.ScriptModule,
@@ -36,73 +26,36 @@ def __init__(
3626
self.tensorizer = tensorizer
3727

3828
@torch.jit.script_method
39-
def forward(self, texts: List[str]):
40-
input_tensors = self.tensorizer(texts=squeeze_1d(texts))
41-
logits = self.model(input_tensors)
42-
return self.output_layer(logits)
43-
44-
45-
class ScriptTokenModule(ScriptModule):
46-
def __init__(
47-
self,
48-
model: torch.jit.ScriptModule,
49-
output_layer: torch.jit.ScriptModule,
50-
tensorizer: ScriptTensorizer,
51-
):
52-
super().__init__()
53-
self.model = model
54-
self.output_layer = output_layer
55-
self.tensorizer = tensorizer
56-
57-
@torch.jit.script_method
58-
def forward(self, tokens: List[List[str]]):
59-
input_tensors = self.tensorizer(pre_tokenized=squeeze_2d(tokens))
60-
logits = self.model(input_tensors)
61-
return self.output_layer(logits)
62-
63-
64-
class ScriptTokenLanguageModule(ScriptModule):
65-
def __init__(
29+
def forward(
6630
self,
67-
model: torch.jit.ScriptModule,
68-
output_layer: torch.jit.ScriptModule,
69-
tensorizer: ScriptTensorizer,
31+
texts: Optional[List[str]] = None,
32+
tokens: Optional[List[List[str]]] = None,
33+
languages: Optional[List[str]] = None,
7034
):
71-
super().__init__()
72-
self.model = model
73-
self.output_layer = output_layer
74-
self.tensorizer = tensorizer
75-
76-
@torch.jit.script_method
77-
def forward(self, tokens: List[List[str]], languages: Optional[List[str]] = None):
78-
input_tensors = self.tensorizer(
79-
pre_tokenized=squeeze_2d(tokens), languages=squeeze_1d(languages)
35+
inputs: ScriptBatchInput = ScriptBatchInput(
36+
texts=squeeze_1d(texts),
37+
tokens=squeeze_2d(tokens),
38+
languages=squeeze_1d(languages),
8039
)
40+
input_tensors = self.tensorizer(inputs)
8141
logits = self.model(input_tensors)
8242
return self.output_layer(logits)
8343

8444

85-
class ScriptTokenLanguageModuleWithDenseFeature(ScriptModule):
86-
def __init__(
87-
self,
88-
model: torch.jit.ScriptModule,
89-
output_layer: torch.jit.ScriptModule,
90-
tensorizer: ScriptTensorizer,
91-
):
92-
super().__init__()
93-
self.model = model
94-
self.output_layer = output_layer
95-
self.tensorizer = tensorizer
96-
45+
class ScriptPyTextModuleWithDense(ScriptPyTextModule):
9746
@torch.jit.script_method
9847
def forward(
9948
self,
100-
tokens: List[List[str]],
10149
dense_feat: List[List[float]],
50+
texts: Optional[List[str]] = None,
51+
tokens: Optional[List[List[str]]] = None,
10252
languages: Optional[List[str]] = None,
10353
):
104-
input_tensors = self.tensorizer(
105-
pre_tokenized=squeeze_2d(tokens), languages=squeeze_1d(languages)
54+
inputs: ScriptBatchInput = ScriptBatchInput(
55+
texts=squeeze_1d(texts),
56+
tokens=squeeze_2d(tokens),
57+
languages=squeeze_1d(languages),
10658
)
59+
input_tensors = self.tensorizer(inputs)
10760
logits = self.model(input_tensors, torch.tensor(dense_feat).float())
10861
return self.output_layer(logits)

pytext/torchscript/tests/test_tensorizer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
)
1414
from pytext.torchscript.tensorizer.tensorizer import VocabLookup
1515
from pytext.torchscript.tokenizer import ScriptDoNothingTokenizer
16-
from pytext.torchscript.tokenizer.tokenizer import ScriptTextTokenizerBase
16+
from pytext.torchscript.tokenizer.tokenizer import ScriptTokenizerBase
1717
from pytext.torchscript.utils import squeeze_1d, squeeze_2d
1818
from pytext.torchscript.vocab import ScriptVocabulary
1919

@@ -26,7 +26,7 @@ def _mock_vocab(self):
2626
)
2727

2828
def _mock_tokenizer(self):
29-
class MockTokenizer(ScriptTextTokenizerBase):
29+
class MockTokenizer(ScriptTokenizerBase):
3030
def __init__(self, tokens: List[Tuple[str, int, int]]):
3131
super().__init__()
3232
self.tokens = torch.jit.Attribute(tokens, List[Tuple[str, int, int]])

pytext/torchscript/tokenizer/__init__.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,12 @@
22
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
33

44
from .bpe import ScriptBPE
5-
from .tokenizer import (
6-
ScriptBPETokenizer,
7-
ScriptDoNothingTokenizer,
8-
ScriptTextTokenizerBase,
9-
ScriptTokenTokenizerBase,
10-
)
5+
from .tokenizer import ScriptBPETokenizer, ScriptDoNothingTokenizer, ScriptTokenizerBase
116

127

138
__all__ = [
149
"ScriptBPE",
1510
"ScriptBPETokenizer",
1611
"ScriptDoNothingTokenizer",
17-
"ScriptTextTokenizerBase",
18-
"ScriptTokenTokenizerBase",
12+
"ScriptTokenizerBase",
1913
]

0 commit comments

Comments
 (0)