Skip to content
This repository was archived by the owner on Nov 22, 2022. It is now read-only.

Commit 5ab7ded

Browse files
shreydesaifacebook-github-bot
authored andcommitted
enabled lmlstm caffe2 exporting (#766)
Summary: Pull Request resolved: #766 Enables caffe2 exporting for LMLSTM Reviewed By: seayoung1112 Differential Revision: D16190189 fbshipit-source-id: 8e04dfd987755d8c1f6ab52e476b97d5b5d596ab
1 parent fcaa931 commit 5ab7ded

File tree

1 file changed

+16
-0
lines changed

1 file changed

+16
-0
lines changed

pytext/models/language_models/lmlstm.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from pytext.config import ConfigBase
88
from pytext.data import CommonMetadata
99
from pytext.data.tensorizers import Tensorizer, TokenTensorizer
10+
from pytext.exporters.exporter import ModelExporter
1011
from pytext.models.decoders import DecoderBase
1112
from pytext.models.decoders.mlp_decoder import MLPDecoder
1213
from pytext.models.embeddings import EmbeddingBase
@@ -217,6 +218,11 @@ def __init__(
217218
self.module_list = [embedding, representation, decoder]
218219
self._states: Optional[Tuple] = None
219220

221+
def cpu(self):
222+
if self.stateful:
223+
self._states = (self._states[0].cpu(), self._states[1].cpu())
224+
return self._apply(lambda t: t.cpu())
225+
220226
def arrange_model_inputs(self, tensor_dict):
221227
tokens, seq_lens, _ = tensor_dict["tokens"]
222228
# Omit last token because it won't have a corresponding target
@@ -236,6 +242,16 @@ def get_export_output_names(self, tensorizers):
236242
def vocab_to_export(self, tensorizers):
237243
return {"tokens_vals": list(tensorizers["tokens"].vocab)}
238244

245+
def caffe2_export(self, tensorizers, tensor_dict, path, export_onnx_path=None):
246+
exporter = ModelExporter(
247+
ModelExporter.Config(),
248+
self.get_export_input_names(tensorizers),
249+
self.arrange_model_inputs(tensor_dict),
250+
self.vocab_to_export(tensorizers),
251+
self.get_export_output_names(tensorizers),
252+
)
253+
return exporter.export_to_caffe2(self, path, export_onnx_path=export_onnx_path)
254+
239255
def forward(
240256
self, tokens: torch.Tensor, seq_len: torch.Tensor
241257
) -> List[torch.Tensor]:

0 commit comments

Comments
 (0)