This repository was archived by the owner on Nov 22, 2022. It is now read-only.

Add support for torchscriptification of XLM intent slot models #1167

Closed
9 changes: 4 additions & 5 deletions pytext/models/output_layers/intent_slot_output_layer.py
@@ -25,19 +25,18 @@ def __init__(self, doc_scores: jit.ScriptModule, word_scores: jit.ScriptModule):
     def forward(
         self,
         logits: Tuple[torch.Tensor, torch.Tensor],
-        seq_lengths: torch.Tensor,
-        token_indices: Optional[torch.Tensor] = None,
+        context: Dict[str, torch.Tensor],
     ) -> Tuple[List[Dict[str, float]], List[List[Dict[str, float]]]]:
         d_logits, w_logits = logits
-        if token_indices is not None:
+        if "token_indices" in context:
             w_logits = torch.gather(
                 w_logits,
                 1,
-                token_indices.unsqueeze(2).expand(-1, -1, w_logits.size(-1)),
+                context["token_indices"].unsqueeze(2).expand(-1, -1, w_logits.size(-1)),
             )
 
         d_results = self.doc_scores(d_logits)
-        w_results = self.word_scores(w_logits, seq_lengths)
+        w_results = self.word_scores(w_logits, context)
         return d_results, w_results


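For reference, a minimal standalone sketch of what the new gather does (toy tensors and a hypothetical context dict, not the PyText fixtures): when "token_indices" is present, the word logits are re-indexed along the token dimension, for example keeping one BPE position per word, before they are scored.

import torch

batch_size, num_bpe_tokens, num_labels = 2, 6, 4
w_logits = torch.randn(batch_size, num_bpe_tokens, num_labels)

# Hypothetical context: keep three BPE positions per example.
context = {"token_indices": torch.tensor([[0, 2, 4], [1, 3, 5]])}

if "token_indices" in context:
    w_logits = torch.gather(
        w_logits,
        1,
        context["token_indices"].unsqueeze(2).expand(-1, -1, w_logits.size(-1)),
    )

print(w_logits.shape)  # torch.Size([2, 3, 4]): one row of logits per kept position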
8 changes: 5 additions & 3 deletions pytext/models/output_layers/word_tagging_output_layer.py
@@ -35,7 +35,7 @@ def __init__(self, classes):
         self.classes = classes
 
     def forward(
-        self, logits: torch.Tensor, seq_lengths: Optional[torch.Tensor] = None
+        self, logits: torch.Tensor, context: Optional[Dict[str, torch.Tensor]] = None
     ) -> List[List[Dict[str, float]]]:
         scores: torch.Tensor = F.log_softmax(logits, 2)
         return _get_prediction_from_scores(scores, self.classes)
@@ -48,9 +48,11 @@ def __init__(self, classes: List[str], crf):
         self.crf.eval()
 
     def forward(
-        self, logits: torch.Tensor, seq_lengths: torch.Tensor
+        self, logits: torch.Tensor, context: Dict[str, torch.Tensor]
     ) -> List[List[Dict[str, float]]]:
-        pred = self.crf.decode(logits, seq_lengths)
+        # We need seq_lengths for CRF decode
+        assert "seq_lens" in context
+        pred = self.crf.decode(logits, context["seq_lens"])
         logits_rearranged = _rearrange_output(logits, pred)
         scores: torch.Tensor = F.log_softmax(logits_rearranged, 2)
         return _get_prediction_from_scores(scores, self.classes)
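As a rough illustration of why the dict-based signature helps here (a toy module, not PyText's CRF word-tagging scorer, and the masking stand-in for crf.decode is an assumption): TorchScript supports Dict[str, Tensor] inputs, so a required extra such as "seq_lens" can travel in the same context argument and be checked with an assert before decoding.

from typing import Dict

import torch
from torch import jit, nn


class ToyWordScores(nn.Module):
    def forward(
        self, logits: torch.Tensor, context: Dict[str, torch.Tensor]
    ) -> torch.Tensor:
        # Mirrors the layer above: sequence lengths are required for decoding.
        assert "seq_lens" in context
        seq_lens = context["seq_lens"]
        # Stand-in for crf.decode: argmax over labels, then zero out padding.
        preds = logits.argmax(dim=2)
        mask = torch.arange(logits.size(1)).unsqueeze(0) < seq_lens.unsqueeze(1)
        return preds * mask.to(preds.dtype)


scripted = jit.script(ToyWordScores())
logits = torch.randn(2, 5, 3)
print(scripted(logits, {"seq_lens": torch.tensor([5, 4])}).shape)  # torch.Size([2, 5])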
10 changes: 4 additions & 6 deletions pytext/models/test/output_layer_test.py
@@ -60,12 +60,12 @@ def test_torchscript_word_tagging_output_layer(self, num_labels, seq_lens):
 
         self._validate_word_tagging_result(
             word_layer.get_pred(logits, None, context)[1],
-            torchsript_word_layer(logits, seq_lens_tensor),
+            torchsript_word_layer(logits, context),
             vocab,
         )
         self._validate_word_tagging_result(
             crf_layer.get_pred(logits, None, context)[1],
-            torchscript_crf_layer(logits, seq_lens_tensor),
+            torchscript_crf_layer(logits, context),
             vocab,
         )
 
@@ -103,7 +103,7 @@ def test_torchscript_intent_slot_output_layer(
         pt_output = intent_slot_output_layer.get_pred(
             (doc_logits, word_logits), None, context
         )[1]
-        ts_output = torchscript_output_layer((doc_logits, word_logits), seq_lens_tensor)
+        ts_output = torchscript_output_layer((doc_logits, word_logits), context)
 
         self._validate_doc_classification_result(pt_output[0], ts_output[0], doc_vocab)
         self._validate_word_tagging_result(pt_output[1], ts_output[1], word_vocab)
@@ -119,9 +119,7 @@ def test_torchscript_intent_slot_output_layer(
         pt_output = intent_slot_output_layer.get_pred(
             (doc_logits, word_bpe_logits), None, context
         )[1]
-        ts_output = torchscript_output_layer(
-            (doc_logits, word_bpe_logits), seq_lens_tensor, token_indices_tensor
-        )
+        ts_output = torchscript_output_layer((doc_logits, word_bpe_logits), context)
 
         self._validate_doc_classification_result(pt_output[0], ts_output[0], doc_vocab)
         self._validate_word_tagging_result(pt_output[1], ts_output[1], word_vocab)
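The test-side change follows directly: instead of passing seq_lens_tensor (and, in the BPE case, token_indices_tensor) as extra positional arguments, the scripted layers are called with the same context dict that get_pred receives. A hedged sketch of that calling convention with toy values; the real fixtures are built in the hidden parts of output_layer_test.py, so the exact dict contents in each case are an assumption.

import torch

seq_lens_tensor = torch.tensor([6, 5])
token_indices_tensor = torch.tensor([[0, 2, 4], [1, 3, 4]])

# Word-level case: the layers only look up "seq_lens".
context = {"seq_lens": seq_lens_tensor}
# ts_output = torchscript_output_layer((doc_logits, word_logits), context)

# BPE case: "token_indices" rides along in the same dict instead of being a
# third positional argument.
context["token_indices"] = token_indices_tensor
# ts_output = torchscript_output_layer((doc_logits, word_bpe_logits), context)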