Skip to content
This repository was archived by the owner on Nov 22, 2022. It is now read-only.

Commit e7694cd

Browse files
geof90 authored and facebook-github-bot committed
Support bytes input in word tagging model OSS (#745)
Summary: Pull Request resolved: #745 As title Reviewed By: bethebunny Differential Revision: D16078832 fbshipit-source-id: 7625dd9268acb82cae401da594ba235d15cff3c0
1 parent 04e9be7 commit e7694cd

File tree

1 file changed

+63
-2
lines changed

1 file changed

+63
-2
lines changed

pytext/models/word_model.py

Lines changed: 63 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,15 @@
33

44
from typing import Union
55

6-
from pytext.data.tensorizers import SlotLabelTensorizer, TokenTensorizer
6+
from pytext.data.tensorizers import (
7+
ByteTokenTensorizer,
8+
SlotLabelTensorizer,
9+
TokenTensorizer,
10+
)
711
from pytext.data.utils import UNK
12+
from pytext.exporters.exporter import ModelExporter
813
from pytext.models.decoders.mlp_decoder import MLPDecoder
9-
from pytext.models.embeddings import WordEmbedding
14+
from pytext.models.embeddings import CharacterEmbedding, WordEmbedding
1015
from pytext.models.model import Model
1116
from pytext.models.module import create_module
1217
from pytext.models.output_layers import CRFOutputLayer, WordTaggingOutputLayer
@@ -52,6 +57,9 @@ def __init__(self, *args, **kwargs):
5257

5358

5459
class WordTaggingModel(Model):
60+
61+
__EXPANSIBLE__ = True
62+
5563
class Config(Model.Config):
5664
class ModelInput(Model.Config.ModelInput):
5765
tokens: TokenTensorizer.Config = TokenTensorizer.Config()
@@ -108,3 +116,56 @@ def arrange_model_inputs(self, tensor_dict):
108116

109117
def arrange_targets(self, tensor_dict):
110118
return tensor_dict["labels"]
119+
120+
def get_export_input_names(self, tensorizers):
121+
return "tokens", "tokens_lens"
122+
123+
def get_export_output_names(self, tensorizers):
124+
return ["word_scores"]
125+
126+
def vocab_to_export(self, tensorizers):
127+
return {"tokens": list(tensorizers["tokens"].vocab)}
128+
129+
def caffe2_export(self, tensorizers, tensor_dict, path, export_onnx_path=None):
130+
exporter = ModelExporter(
131+
ModelExporter.Config(),
132+
self.get_export_input_names(tensorizers),
133+
self.arrange_model_inputs(tensor_dict),
134+
self.vocab_to_export(tensorizers),
135+
self.get_export_output_names(tensorizers),
136+
)
137+
return exporter.export_to_caffe2(self, path, export_onnx_path=export_onnx_path)
138+
139+
140+
class WordTaggingLiteModel(WordTaggingModel):
141+
"""
142+
Also a word tagging model, but uses bytes as inputs to the model. Using
143+
bytes instead of words, the model does not need to store a word embedding
144+
table mapping words in the vocab to their embedding vector representations,
145+
but instead compute them on the fly using CharacterEmbedding. This produces
146+
an exported/serialized model that requires much less storage space as well
147+
as less memory during run/inference time.
148+
"""
149+
150+
class Config(WordTaggingModel.Config):
151+
class ModelInput(WordTaggingModel.Config.ModelInput):
152+
# We should support characters as well, but CharacterTokenTensorizer
153+
# does not support adding characters to vocab yet.
154+
tokens: ByteTokenTensorizer.Config = ByteTokenTensorizer.Config()
155+
156+
inputs: ModelInput = ModelInput()
157+
embedding: CharacterEmbedding.Config = CharacterEmbedding.Config()
158+
159+
@classmethod
160+
def create_embedding(cls, config, tensorizers):
161+
return CharacterEmbedding(
162+
tensorizers["tokens"].NUM_BYTES,
163+
config.embedding.embed_dim,
164+
config.embedding.cnn.kernel_num,
165+
config.embedding.cnn.kernel_sizes,
166+
config.embedding.highway_layers,
167+
config.embedding.projection_dim,
168+
)
169+
170+
def vocab_to_export(self, tensorizers):
171+
return {}

0 commit comments

Comments
 (0)