This repository was archived by the owner on Nov 22, 2022. It is now read-only.

Commit 77753c6

chenyangyu1988 authored and facebook-github-bot committed
Support embedding from decoder (#1284)
Summary: Pull Request resolved: #1284. Support embedding from decoder.

Reviewed By: mwu1993

Differential Revision: D20515702

fbshipit-source-id: 0101c0048c0935eaef641a8bfa92a9e55afc73d3
1 parent 1a9fa0b commit 77753c6

File tree

1 file changed: 7 additions, 2 deletions


pytext/torchscript/module.py

Lines changed: 7 additions & 2 deletions
```diff
@@ -110,9 +110,11 @@ def __init__(
         tensorizer: ScriptTensorizer,
         normalizer: VectorNormalizer,
         index: int = 0,
+        concat_dense: bool = True,
     ):
         super().__init__(model, tensorizer, index)
         self.normalizer = normalizer
+        self.concat_dense = torch.jit.Attribute(concat_dense, bool)

     @torch.jit.script_method
     def forward(
@@ -135,5 +137,8 @@ def forward(
         dense_feat = self.normalizer.normalize(dense_feat)
         dense_tensor = torch.tensor(dense_feat, dtype=torch.float)

-        encoder_embedding = self.model(input_tensors, dense_tensor)[self.index]
-        return torch.cat([encoder_embedding, dense_tensor], 1)
+        sentence_embedding = self.model(input_tensors, dense_tensor)[self.index]
+        if self.concat_dense:
+            return torch.cat([sentence_embedding, dense_tensor], 1)
+        else:
+            return sentence_embedding
```
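The pattern introduced here, storing a boolean as a `torch.jit.Attribute` so a scripted `forward` can branch on it, can be exercised in isolation. The sketch below is illustrative only: `EmbeddingModule` and the tensor shapes are made up for the example and are not PyText's actual classes, which also involve a model, tensorizer, and normalizer.

```python
import torch

class EmbeddingModule(torch.jit.ScriptModule):
    """Toy module mirroring the concat_dense branch added in this commit."""

    def __init__(self, concat_dense: bool = True):
        super().__init__()
        # Stored via torch.jit.Attribute so the flag is available to the
        # TorchScript-compiled forward method.
        self.concat_dense = torch.jit.Attribute(concat_dense, bool)

    @torch.jit.script_method
    def forward(
        self, sentence_embedding: torch.Tensor, dense_tensor: torch.Tensor
    ) -> torch.Tensor:
        if self.concat_dense:
            # Append the normalized dense features to the embedding.
            return torch.cat([sentence_embedding, dense_tensor], 1)
        else:
            # New behavior enabled by concat_dense=False: return the
            # embedding alone.
            return sentence_embedding
```

With `concat_dense=True` (the default, preserving the old behavior) the output width is the embedding width plus the dense-feature width; with `concat_dense=False` the raw embedding passes through unchanged.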
