7
7
from pytext .config import ConfigBase
8
8
from pytext .data import CommonMetadata
9
9
from pytext .data .tensorizers import Tensorizer , TokenTensorizer
10
+ from pytext .exporters .exporter import ModelExporter
10
11
from pytext .models .decoders import DecoderBase
11
12
from pytext .models .decoders .mlp_decoder import MLPDecoder
12
13
from pytext .models .embeddings import EmbeddingBase
@@ -217,6 +218,11 @@ def __init__(
217
218
self .module_list = [embedding , representation , decoder ]
218
219
self ._states : Optional [Tuple ] = None
219
220
221
def cpu(self):
    """Move the module (and its cached recurrent states, if any) to CPU.

    Mirrors ``torch.nn.Module.cpu``, but additionally migrates
    ``self._states``: the states are held in a plain tuple (not registered
    buffers/parameters), so ``_apply`` would not traverse them on its own.

    Returns:
        The result of ``self._apply`` (by torch convention, the module itself).
    """
    # Guard against _states still being None: a stateful model that has not
    # yet run a forward pass has no cached states to move (the original code
    # raised TypeError on ``self._states[0]`` in that situation).
    if self.stateful and self._states is not None:
        # _states is indexed as a 2-tuple here — presumably the LSTM (h, c)
        # pair; NOTE(review): confirm against where _states is assigned.
        self._states = (self._states[0].cpu(), self._states[1].cpu())
    return self._apply(lambda t: t.cpu())
225
+
220
226
def arrange_model_inputs (self , tensor_dict ):
221
227
tokens , seq_lens , _ = tensor_dict ["tokens" ]
222
228
# Omit last token because it won't have a corresponding target
@@ -236,6 +242,16 @@ def get_export_output_names(self, tensorizers):
236
242
def vocab_to_export(self, tensorizers):
    """Return the vocab mapping to bundle with an exported model.

    Args:
        tensorizers: mapping holding a "tokens" tensorizer whose ``vocab``
            attribute is an iterable of token strings.

    Returns:
        Dict with a single key ``"tokens_vals"`` mapped to the token vocab
        materialized as a list.
    """
    token_vocab = tensorizers["tokens"].vocab
    return {"tokens_vals": list(token_vocab)}
238
244
245
+ def caffe2_export (self , tensorizers , tensor_dict , path , export_onnx_path = None ):
246
+ exporter = ModelExporter (
247
+ ModelExporter .Config (),
248
+ self .get_export_input_names (tensorizers ),
249
+ self .arrange_model_inputs (tensor_dict ),
250
+ self .vocab_to_export (tensorizers ),
251
+ self .get_export_output_names (tensorizers ),
252
+ )
253
+ return exporter .export_to_caffe2 (self , path , export_onnx_path = export_onnx_path )
254
+
239
255
def forward (
240
256
self , tokens : torch .Tensor , seq_len : torch .Tensor
241
257
) -> List [torch .Tensor ]:
0 commit comments