Skip to content
This repository was archived by the owner on Nov 22, 2022. It is now read-only.

Commit 07c2d7f

Browse files
Michael Wu authored and facebook-github-bot committed
Enable dense features in ByteTokensDocumentModel (#763)
Summary: Pull Request resolved: #763 Update a couple model methods to support dense features, as the parent model `DocModel` already supports dense features. Differential Revision: D16187410 fbshipit-source-id: b7fc7f4e753a4bea4aaa1f3161a454d54c791958
1 parent 2a8f365 commit 07c2d7f

File tree

1 file changed

+40
-4
lines changed

1 file changed

+40
-4
lines changed

pytext/models/doc_model.py

Lines changed: 40 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,9 @@ def create_embedding(cls, config, tensorizers: Dict[str, Tensorizer]):
225225
assert word_tensorizer.column == byte_tensorizer.column
226226

227227
word_embedding = create_module(
228-
config.embedding, tensorizer=tensorizers["tokens"]
228+
config.embedding,
229+
tensorizer=tensorizers["tokens"],
230+
init_from_saved_state=config.init_from_saved_state,
229231
)
230232
byte_embedding = CharacterEmbedding(
231233
ByteTokenTensorizer.NUM_BYTES,
@@ -241,10 +243,16 @@ def arrange_model_inputs(self, tensor_dict):
241243
tokens, seq_lens, _ = tensor_dict["tokens"]
242244
token_bytes, byte_seq_lens, _ = tensor_dict["token_bytes"]
243245
assert (seq_lens == byte_seq_lens).all().item()
244-
return tokens, token_bytes, seq_lens
246+
model_inputs = tokens, token_bytes, seq_lens
247+
if "dense" in tensor_dict:
248+
model_inputs += (tensor_dict["dense"],)
249+
return model_inputs
245250

246251
def get_export_input_names(self, tensorizers):
247-
return ["tokens", "token_bytes", "tokens_lens"]
252+
names = ["tokens", "token_bytes", "tokens_lens"]
253+
if "dense" in tensorizers:
254+
names.append("float_vec_vals")
255+
return names
248256

249257
def torchscriptify(self, tensorizers, traced_model):
250258
output_layer = self.output_layer.torchscript_predictions()
@@ -277,7 +285,35 @@ def forward(self, tokens: List[List[str]]):
277285
)
278286
return self.output_layer(logits)
279287

280-
return Model()
288+
class ModelWithDenseFeat(jit.ScriptModule):
289+
def __init__(self):
290+
super().__init__()
291+
self.vocab = Vocabulary(input_vocab, unk_idx=input_vocab.idx[UNK])
292+
self.max_byte_len = jit.Attribute(max_byte_len, int)
293+
self.byte_offset_for_non_padding = jit.Attribute(
294+
byte_offset_for_non_padding, int
295+
)
296+
self.pad_idx = jit.Attribute(input_vocab.idx[PAD], int)
297+
self.model = traced_model
298+
self.output_layer = output_layer
299+
300+
@jit.script_method
301+
def forward(self, tokens: List[List[str]], dense_feat: List[List[float]]):
302+
seq_lens = make_sequence_lengths(tokens)
303+
word_ids = self.vocab.lookup_indices_2d(tokens)
304+
word_ids = pad_2d(word_ids, seq_lens, self.pad_idx)
305+
token_bytes, _ = make_byte_inputs(
306+
tokens, self.max_byte_len, self.byte_offset_for_non_padding
307+
)
308+
logits = self.model(
309+
torch.tensor(word_ids),
310+
token_bytes,
311+
torch.tensor(seq_lens),
312+
torch.tensor(dense_feat),
313+
)
314+
return self.output_layer(logits)
315+
316+
return ModelWithDenseFeat() if "dense" in tensorizers else Model()
281317

282318

283319
class DocRegressionModel(DocModel):

0 commit comments

Comments (0)