3
3
4
4
from typing import Union
5
5
6
- from pytext .data .tensorizers import SlotLabelTensorizer , TokenTensorizer
6
+ from pytext .data .tensorizers import (
7
+ ByteTokenTensorizer ,
8
+ SlotLabelTensorizer ,
9
+ TokenTensorizer ,
10
+ )
7
11
from pytext .data .utils import UNK
12
+ from pytext .exporters .exporter import ModelExporter
8
13
from pytext .models .decoders .mlp_decoder import MLPDecoder
9
- from pytext .models .embeddings import WordEmbedding
14
+ from pytext .models .embeddings import CharacterEmbedding , WordEmbedding
10
15
from pytext .models .model import Model
11
16
from pytext .models .module import create_module
12
17
from pytext .models .output_layers import CRFOutputLayer , WordTaggingOutputLayer
@@ -54,11 +59,17 @@ def __init__(self, *args, **kwargs):
54
59
class WordTaggingModel (Model ):
55
60
class Config (Model .Config ):
56
61
class ModelInput (Model .Config .ModelInput ):
57
- tokens : TokenTensorizer .Config = TokenTensorizer .Config ()
62
+ # We should support characters as well, but CharacterTokenTensorizer
63
+ # does not support adding characters to vocab yet.
64
+ tokens : Union [
65
+ ByteTokenTensorizer .Config , TokenTensorizer .Config
66
+ ] = TokenTensorizer .Config ()
58
67
labels : SlotLabelTensorizer .Config = SlotLabelTensorizer .Config ()
59
68
60
69
inputs : ModelInput = ModelInput ()
61
- embedding : WordEmbedding .Config = WordEmbedding .Config ()
70
+ embedding : Union [
71
+ WordEmbedding .Config , CharacterEmbedding .Config
72
+ ] = WordEmbedding .Config ()
62
73
63
74
representation : Union [
64
75
BiLSTMSlotAttention .Config , # TODO: make default when sorting solved
@@ -72,10 +83,21 @@ class ModelInput(Model.Config.ModelInput):
72
83
73
84
@classmethod
def create_embedding(cls, config, tensorizers):
    """Build the embedding module that matches the configured token tensorizer.

    Word-level tokens (``TokenTensorizer``) get a ``WordEmbedding`` sized to the
    tensorizer's vocab; otherwise the tensorizer is assumed to be byte-based
    (``ByteTokenTensorizer`` per the ModelInput union) and a ``CharacterEmbedding``
    over its fixed byte alphabet is built instead.
    """
    tok = tensorizers["tokens"]
    if isinstance(tok, TokenTensorizer):
        # Word path: table has one row per vocab entry; UNK index comes
        # from the tensorizer's vocab.
        vocabulary = tok.vocab
        return WordEmbedding(
            len(vocabulary),
            config.embedding.embed_dim,
            None,
            None,
            vocabulary.idx[UNK],
            [],
        )
    # Byte path: embed each of NUM_BYTES possible byte values, then apply the
    # configured CNN / highway / projection stack.
    return CharacterEmbedding(
        tok.NUM_BYTES,
        config.embedding.embed_dim,
        config.embedding.cnn.kernel_num,
        config.embedding.cnn.kernel_sizes,
        config.embedding.highway_layers,
        config.embedding.projection_dim,
    )
79
101
80
102
@classmethod
81
103
def from_config (cls , config , tensorizers ):
@@ -108,3 +130,26 @@ def arrange_model_inputs(self, tensor_dict):
108
130
109
131
def arrange_targets(self, tensor_dict):
    """Return the training target: the tensor stored under ``"labels"``."""
    targets = tensor_dict["labels"]
    return targets
133
+
134
def get_export_input_names(self, tensorizers):
    """Input blob names used when exporting this model (tokens + lengths)."""
    return "tokens", "tokens_lens"
136
+
137
def get_export_output_names(self, tensorizers):
    """Output blob names for the exported net: one per-word score tensor."""
    names = ["word_scores"]
    return names
139
+
140
def vocab_to_export(self, tensorizers):
    """Map export input name -> vocab entries to embed in the exported model.

    Only the word-level path has a vocab to export; the byte-level tensorizer
    needs none, so an empty dict is returned for it.
    """
    tok = tensorizers["tokens"]
    if not isinstance(tok, TokenTensorizer):
        return {}
    return {"tokens": list(tok.vocab)}
146
+
147
def caffe2_export(self, tensorizers, tensor_dict, path, export_onnx_path=None):
    """Export this model to Caffe2 at ``path`` via the generic ModelExporter.

    ``export_onnx_path`` optionally saves the intermediate ONNX model as well.
    Returns whatever ``ModelExporter.export_to_caffe2`` returns.
    """
    input_names = self.get_export_input_names(tensorizers)
    dummy_inputs = self.arrange_model_inputs(tensor_dict)
    export_vocab = self.vocab_to_export(tensorizers)
    output_names = self.get_export_output_names(tensorizers)
    exporter = ModelExporter(
        ModelExporter.Config(),
        input_names,
        dummy_inputs,
        export_vocab,
        output_names,
    )
    return exporter.export_to_caffe2(self, path, export_onnx_path=export_onnx_path)
0 commit comments