This repository was archived by the owner on Nov 22, 2022. It is now read-only.

Commit 1f0c4ea

geof90 authored and facebook-github-bot committed
Support bytes input in joint intent-slot model OSS (#745)
Summary: Pull Request resolved: #745. As title.
Differential Revision: D16078832
fbshipit-source-id: e10d4cb5c01ae7ba71d72ba0edc0af54ac0190db
1 parent 04e9be7 commit 1f0c4ea
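
For readers skimming the diff below: the change widens WordTaggingModel's config so token input can come from ByteTokenTensorizer (raw bytes) instead of TokenTensorizer (a word vocabulary), pairs byte input with CharacterEmbedding, and adds the export hooks the byte path needs. A minimal usage sketch, not part of the commit; the keyword-argument Config construction is an assumption, while the field names come straight from the diff:

# Sketch only: opting the model into byte-level input via the new
# Union-typed config fields from this commit.
from pytext.data.tensorizers import ByteTokenTensorizer
from pytext.models.embeddings import CharacterEmbedding
from pytext.models.word_model import WordTaggingModel

config = WordTaggingModel.Config(
    inputs=WordTaggingModel.Config.ModelInput(
        # Bytes instead of a word vocabulary; no vocab building needed.
        tokens=ByteTokenTensorizer.Config(),
    ),
    # CharacterEmbedding consumes the byte ids ByteTokenTensorizer emits.
    embedding=CharacterEmbedding.Config(),
)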


pytext/models/word_model.py

Lines changed: 53 additions & 8 deletions
@@ -3,10 +3,15 @@
 
 from typing import Union
 
-from pytext.data.tensorizers import SlotLabelTensorizer, TokenTensorizer
+from pytext.data.tensorizers import (
+    ByteTokenTensorizer,
+    SlotLabelTensorizer,
+    TokenTensorizer,
+)
 from pytext.data.utils import UNK
+from pytext.exporters.exporter import ModelExporter
 from pytext.models.decoders.mlp_decoder import MLPDecoder
-from pytext.models.embeddings import WordEmbedding
+from pytext.models.embeddings import CharacterEmbedding, WordEmbedding
 from pytext.models.model import Model
 from pytext.models.module import create_module
 from pytext.models.output_layers import CRFOutputLayer, WordTaggingOutputLayer
@@ -54,11 +59,17 @@ def __init__(self, *args, **kwargs):
 class WordTaggingModel(Model):
     class Config(Model.Config):
         class ModelInput(Model.Config.ModelInput):
-            tokens: TokenTensorizer.Config = TokenTensorizer.Config()
+            # We should support characters as well, but CharacterTokenTensorizer
+            # does not support adding characters to vocab yet.
+            tokens: Union[
+                ByteTokenTensorizer.Config, TokenTensorizer.Config
+            ] = TokenTensorizer.Config()
             labels: SlotLabelTensorizer.Config = SlotLabelTensorizer.Config()
 
         inputs: ModelInput = ModelInput()
-        embedding: WordEmbedding.Config = WordEmbedding.Config()
+        embedding: Union[
+            WordEmbedding.Config, CharacterEmbedding.Config
+        ] = WordEmbedding.Config()
 
         representation: Union[
             BiLSTMSlotAttention.Config,  # TODO: make default when sorting solved
@@ -72,10 +83,21 @@ class ModelInput(Model.Config.ModelInput):
 
     @classmethod
     def create_embedding(cls, config, tensorizers):
-        vocab = tensorizers["tokens"].vocab
-        return WordEmbedding(
-            len(vocab), config.embedding.embed_dim, None, None, vocab.idx[UNK], []
-        )
+        token_tensorizer = tensorizers["tokens"]
+        if isinstance(token_tensorizer, TokenTensorizer):
+            vocab = token_tensorizer.vocab
+            return WordEmbedding(
+                len(vocab), config.embedding.embed_dim, None, None, vocab.idx[UNK], []
+            )
+        else:
+            return CharacterEmbedding(
+                token_tensorizer.NUM_BYTES,
+                config.embedding.embed_dim,
+                config.embedding.cnn.kernel_num,
+                config.embedding.cnn.kernel_sizes,
+                config.embedding.highway_layers,
+                config.embedding.projection_dim,
+            )
 
     @classmethod
     def from_config(cls, config, tensorizers):
@@ -108,3 +130,26 @@ def arrange_model_inputs(self, tensor_dict):
 
     def arrange_targets(self, tensor_dict):
         return tensor_dict["labels"]
+
+    def get_export_input_names(self, tensorizers):
+        return "tokens", "tokens_lens"
+
+    def get_export_output_names(self, tensorizers):
+        return ["word_scores"]
+
+    def vocab_to_export(self, tensorizers):
+        token_tensorizer = tensorizers["tokens"]
+        if isinstance(token_tensorizer, TokenTensorizer):
+            return {"tokens": list(token_tensorizer.vocab)}
+
+        return {}
+
+    def caffe2_export(self, tensorizers, tensor_dict, path, export_onnx_path=None):
+        exporter = ModelExporter(
+            ModelExporter.Config(),
+            self.get_export_input_names(tensorizers),
+            self.arrange_model_inputs(tensor_dict),
+            self.vocab_to_export(tensorizers),
+            self.get_export_output_names(tensorizers),
+        )
+        return exporter.export_to_caffe2(self, path, export_onnx_path=export_onnx_path)
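
The new methods wire WordTaggingModel into the standard ModelExporter flow. A hypothetical call sequence, assuming a trained model plus the tensorizers and one arranged batch that the training task would have built (paths are placeholders); the signatures match the diff above:

# Hypothetical export sketch; `tensorizers` and `tensor_dict` stand in
# for objects built by the data pipeline.
model = WordTaggingModel.from_config(config, tensorizers)
# ... training happens here ...
model.caffe2_export(
    tensorizers,
    tensor_dict,                      # one batch, used to trace the graph
    "/tmp/word_tagging.c2",           # placeholder Caffe2 output path
    export_onnx_path="/tmp/word_tagging.onnx",  # optional ONNX dump
)

Note that vocab_to_export only emits a vocabulary on the TokenTensorizer path; with byte input there is no token vocab to ship, so it returns an empty dict.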
