This repository was archived by the owner on Nov 22, 2022. It is now read-only.

Commit 6dd8fcc

abhinavarora authored and facebook-github-bot committed
Add support for Torchscript export of IntentSlotOutputLayer and CRF (#1146)
Summary: Pull Request resolved: #1146

This diff does the following:
1. Modifies `IntentSlotOutputLayer`, `WordTaggingOutputLayer` and `CRFOutputLayer` for Torchscript export.
2. Makes the CRF implementation torchscriptable.
3. Fixes the `predict` method of `NewTask` so that it also passes the model context to `get_pred`.
4. Fixes the return type of the decoder's `forward` method to return tuples of tensors instead of lists of tensors.

Reviewed By: liaimi

Differential Revision: D18565235

fbshipit-source-id: 80836351c96b53f0650fa05ba4b4ab78b866899a
1 parent 7c4cc0b commit 6dd8fcc
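
The heart of the change is that each output layer gains a `torchscript_predictions()` method returning a `jit.script`-ed module that maps logits to per-label scores. A self-contained toy sketch of that scoring pattern (the class and label names below are illustrative, not PyText code):

from typing import Dict, List

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import jit


class ToyScores(nn.Module):
    # TorchScript needs an explicit type for non-tensor module attributes
    classes: List[str]

    def __init__(self, classes: List[str]):
        super().__init__()
        self.classes = classes

    def forward(self, logits: torch.Tensor) -> List[Dict[str, float]]:
        # Map each row of logits to a {label: log-probability} dict
        probs = F.log_softmax(logits, 1)
        results: List[Dict[str, float]] = []
        for row in probs.chunk(len(probs)):
            row = row.squeeze(0)
            response: Dict[str, float] = {}
            for i in range(len(self.classes)):
                response[self.classes[i]] = float(row[i].item())
            results.append(response)
        return results


scripted = jit.script(ToyScores(["intent_a", "intent_b"]))
print(scripted(torch.randn(2, 2)))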

File tree

7 files changed: +416 -31 lines changed


pytext/models/crf.py

Lines changed: 25 additions & 21 deletions

@@ -1,11 +1,11 @@
 #!/usr/bin/env python3
 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+from typing import List
+
 import torch
+import torch.jit as jit
 import torch.nn as nn
 from caffe2.python.crf_predict import apply_crf
-from pytext.common.constants import Padding
-from pytext.utils.cuda import GetTensor
-from torch.autograd import Variable
 
 
 class CRF(nn.Module):
@@ -17,7 +17,9 @@ class CRF(nn.Module):
         num_tags: The number of tags
     """
 
-    def __init__(self, num_tags: int, ignore_index: int) -> None:
+    def __init__(
+        self, num_tags: int, ignore_index: int, default_label_pad_index: int
+    ) -> None:
         if num_tags <= 0:
             raise ValueError(f"Invalid number of tags: {num_tags}")
         super().__init__()
@@ -29,6 +31,7 @@ def __init__(self, num_tags: int, ignore_index: int) -> None:
         self.end_tag = num_tags + 1
         self.reset_parameters()
         self.ignore_index = ignore_index
+        self.default_label_pad_index = default_label_pad_index
 
     def reset_parameters(self) -> None:
         nn.init.uniform_(self.transitions, -0.1, 0.1)
@@ -42,8 +45,8 @@ def set_transitions(self, transitions: torch.Tensor = None):
         self.transitions.data = transitions
 
     def forward(
-        self, emissions: torch.FloatTensor, tags: torch.LongTensor, reduce: bool = True
-    ) -> Variable:
+        self, emissions: torch.Tensor, tags: torch.Tensor, reduce: bool = True
+    ) -> torch.Tensor:
         """
         Compute log-likelihood of input.
 
@@ -62,9 +65,8 @@ def forward(
         llh = numerator - denominator
         return llh if not reduce else torch.mean(llh)
 
-    def decode(
-        self, emissions: torch.FloatTensor, seq_lens: torch.LongTensor
-    ) -> torch.Tensor:
+    @jit.export
+    def decode(self, emissions: torch.Tensor, seq_lens: torch.Tensor) -> torch.Tensor:
         """
         Given a set of emission probabilities, return the predicted tags.
 
@@ -78,10 +80,7 @@ def decode(
         return result
 
     def _compute_joint_llh(
-        self,
-        emissions: torch.FloatTensor,
-        tags: torch.LongTensor,
-        mask: torch.FloatTensor,
+        self, emissions: torch.Tensor, tags: torch.Tensor, mask: torch.Tensor
     ) -> torch.Tensor:
         seq_len = emissions.shape[1]
 
@@ -115,7 +114,7 @@ def _compute_joint_llh(
         return llh.squeeze(1)
 
     def _compute_log_partition_function(
-        self, emissions: torch.FloatTensor, mask: torch.FloatTensor
+        self, emissions: torch.Tensor, mask: torch.Tensor
     ) -> torch.Tensor:
         seq_len = emissions.shape[1]
 
@@ -139,8 +138,9 @@ def _compute_log_partition_function(
         return torch.logsumexp(log_prob.squeeze(1), 1)
 
     def _viterbi_decode(
-        self, emissions: torch.FloatTensor, mask: torch.FloatTensor
+        self, emissions: torch.Tensor, mask: torch.Tensor
     ) -> torch.Tensor:
+        tensor_device = emissions.device
         seq_len = emissions.shape[1]
         mask = mask.to(torch.uint8)
 
@@ -153,10 +153,12 @@ def _viterbi_decode(
             : self.start_tag, self.end_tag
         ].unsqueeze(0)
 
-        best_scores_list = []
+        best_scores_list: List[torch.Tensor] = []
+        # Needed for Torchscript as empty list is assumed to be list of tensors
+        empty_data: List[int] = []
         # If the element has only token, empty tensor in best_paths helps
         # torch.cat() from crashing
-        best_paths_list = [GetTensor(torch.Tensor().long())]
+        best_paths_list = [torch.tensor(empty_data, device=tensor_device).long()]
         best_scores_list.append(end_scores.unsqueeze(1))
 
         for idx in range(1, seq_len):
@@ -185,12 +187,14 @@ def _viterbi_decode(
 
             _, max_indices_from_scores = torch.max(best_scores, 2)
 
-            valid_index_tensor = GetTensor(torch.tensor(0)).long()
-            if self.ignore_index == Padding.DEFAULT_LABEL_PAD_IDX:
+            valid_index_tensor = torch.tensor(0, device=tensor_device).long()
+            if self.ignore_index == self.default_label_pad_index:
                 # No label for padding, so use 0 index.
                 padding_tensor = valid_index_tensor
             else:
-                padding_tensor = GetTensor(torch.tensor(self.ignore_index)).long()
+                padding_tensor = torch.tensor(
+                    self.ignore_index, device=tensor_device
+                ).long()
 
             # Label for the last position is always based on the index with max score
             # For illegal timesteps, we set as ignore_index
@@ -256,7 +260,7 @@ def _make_mask_from_targets(self, targets):
     def _make_mask_from_seq_lens(self, seq_lens):
        seq_lens = seq_lens.view(-1, 1)
         max_len = torch.max(seq_lens)
-        range_tensor = GetTensor(torch.arange(max_len)).unsqueeze(0)
+        range_tensor = torch.arange(max_len, device=seq_lens.device).unsqueeze(0)
         range_tensor = range_tensor.expand(seq_lens.size(0), range_tensor.size(1))
         mask = (range_tensor < seq_lens).float()
         return mask
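
Two mechanical patterns recur in this file's diff: TorchScript infers an unannotated empty list as `List[Tensor]`, so the empty `List[int]` needs an explicit annotation, and tensors are now created on the input's device instead of via the removed `GetTensor` CUDA helper. A minimal, self-contained illustration (not PyText code):

from typing import List

import torch
from torch import jit


@jit.script
def empty_long_tensor_like(emissions: torch.Tensor) -> torch.Tensor:
    # Without the annotation, TorchScript would infer List[Tensor] here
    empty_data: List[int] = []
    # Create the tensor on the same device as the input, replacing GetTensor
    return torch.tensor(empty_data, device=emissions.device).long()


print(empty_long_tensor_like(torch.randn(2, 3)).shape)  # torch.Size([0])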

pytext/models/decoders/intent_slot_model_decoder.py

Lines changed: 3 additions & 3 deletions

@@ -1,7 +1,7 @@
 #!/usr/bin/env python3
 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
 
-from typing import List, Optional
+from typing import List, Optional, Tuple
 
 import torch
 import torch.nn as nn
@@ -77,7 +77,7 @@ def __init__(
 
     def forward(
         self, x_d: torch.Tensor, x_w: torch.Tensor, dense: Optional[torch.Tensor] = None
-    ) -> List[torch.Tensor]:
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         if dense is not None:
             logit_d = self.doc_decoder(torch.cat((x_d, dense), 1))
         else:
@@ -95,7 +95,7 @@ def forward(
             dense = dense.unsqueeze(1).repeat(1, word_input_shape[1], 1)
             x_w = torch.cat((x_w, dense), 2)
 
-        return [logit_d, self.word_decoder(x_w)]
+        return logit_d, self.word_decoder(x_w)
 
     def get_decoder(self) -> List[nn.Module]:
         """Returns the document and word decoder modules.

pytext/models/output_layers/intent_slot_output_layer.py

Lines changed: 32 additions & 0 deletions

@@ -4,16 +4,43 @@
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
+import torch.nn as nn
 from caffe2.python import core
 from pytext.common.constants import DatasetFieldName
 from pytext.data.utils import Vocabulary
 from pytext.models.module import create_module
+from torch import jit
 
 from .doc_classification_output_layer import ClassificationOutputLayer
 from .output_layer_base import OutputLayerBase
 from .word_tagging_output_layer import CRFOutputLayer, WordTaggingOutputLayer
 
 
+class IntentSlotScores(nn.Module):
+    def __init__(self, doc_scores: jit.ScriptModule, word_scores: jit.ScriptModule):
+        super().__init__()
+        self.doc_scores = doc_scores
+        self.word_scores = word_scores
+
+    def forward(
+        self,
+        logits: Tuple[torch.Tensor, torch.Tensor],
+        seq_lengths: torch.Tensor,
+        token_indices: Optional[torch.Tensor] = None,
+    ) -> Tuple[List[Dict[str, float]], List[List[Dict[str, float]]]]:
+        d_logits, w_logits = logits
+        if token_indices is not None:
+            w_logits = torch.gather(
+                w_logits,
+                1,
+                token_indices.unsqueeze(2).expand(-1, -1, w_logits.size(-1)),
+            )
+
+        d_results = self.doc_scores(d_logits)
+        w_results = self.word_scores(w_logits, seq_lengths)
+        return d_results, w_results
+
+
 class IntentSlotOutputLayer(OutputLayerBase):
     """
     Output layer for joint intent classification and slot-filling models.
@@ -161,3 +188,8 @@ def export_to_caffe2(
         ) + self.word_output.export_to_caffe2(
             workspace, init_net, predict_net, model_out[1], word_out_name
         )
+
+    def torchscript_predictions(self):
+        doc_scores = self.doc_output.torchscript_predictions()
+        word_scores = self.word_output.torchscript_predictions()
+        return jit.script(IntentSlotScores(doc_scores, word_scores))
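
Note the composition order in `torchscript_predictions()`: the leaf scores modules are scripted first, then the wrapper that holds them is scripted. A stripped-down sketch of that pattern with toy modules (names are illustrative, not PyText code):

import torch
import torch.nn as nn
from torch import jit


class Leaf(nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.softmax(-1)


class Wrapper(nn.Module):
    def __init__(self, a: nn.Module, b: nn.Module):
        super().__init__()
        self.a = a
        self.b = b

    def forward(self, x: torch.Tensor, y: torch.Tensor):
        # Already-scripted submodules are reused as-is when the wrapper is scripted
        return self.a(x), self.b(y)


wrapped = jit.script(Wrapper(jit.script(Leaf()), jit.script(Leaf())))
print(wrapped(torch.randn(1, 3), torch.randn(1, 4)))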

pytext/models/output_layers/word_tagging_output_layer.py

Lines changed: 87 additions & 2 deletions

@@ -4,13 +4,14 @@
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
+import torch.jit as jit
+import torch.nn as nn
 import torch.nn.functional as F
 from caffe2.python import core
 from pytext.common import Padding
 from pytext.config.component import create_loss
 from pytext.config.serialize import MissingValueError
 from pytext.data.utils import Vocabulary
-from pytext.fields import FieldMeta
 from pytext.loss import (
     AUCPRHingeLoss,
     BinaryCrossEntropyLoss,
@@ -26,6 +27,35 @@
 from .utils import OutputLayerUtils
 
 
+class WordTaggingScores(nn.Module):
+    classes: List[str]
+
+    def __init__(self, classes):
+        super().__init__()
+        self.classes = classes
+
+    def forward(
+        self, logits: torch.Tensor, seq_lengths: Optional[torch.Tensor] = None
+    ) -> List[List[Dict[str, float]]]:
+        scores: torch.Tensor = F.log_softmax(logits, 2)
+        return _get_prediction_from_scores(scores, self.classes)
+
+
+class CRFWordTaggingScores(WordTaggingScores):
+    def __init__(self, classes: List[str], crf):
+        super().__init__(classes)
+        self.crf = crf
+        self.crf.eval()
+
+    def forward(
+        self, logits: torch.Tensor, seq_lengths: torch.Tensor
+    ) -> List[List[Dict[str, float]]]:
+        pred = self.crf.decode(logits, seq_lengths)
+        logits_rearranged = _rearrange_output(logits, pred)
+        scores: torch.Tensor = F.log_softmax(logits_rearranged, 2)
+        return _get_prediction_from_scores(scores, self.classes)
+
+
 class WordTaggingOutputLayer(OutputLayerBase):
     """
     Output layer for word tagging models. It supports `CrossEntropyLoss` per word.
@@ -138,6 +168,9 @@ def export_to_caffe2(
             predict_net, probability_out, model_out, output_name, self.target_names
         )
 
+    def torchscript_predictions(self):
+        return jit.script(WordTaggingScores(self.target_names))
+
 
 class CRFOutputLayer(OutputLayerBase):
     """
@@ -160,7 +193,11 @@ def from_config(cls, config: OutputLayerBase.Config, labels: Vocabulary):
 
     def __init__(self, num_tags, labels: Vocabulary, *args) -> None:
         super().__init__(list(labels), *args)
-        self.crf = CRF(num_tags, labels.get_pad_index(Padding.DEFAULT_LABEL_PAD_IDX))
+        self.crf = CRF(
+            num_tags=num_tags,
+            ignore_index=labels.get_pad_index(Padding.DEFAULT_LABEL_PAD_IDX),
+            default_label_pad_index=Padding.DEFAULT_LABEL_PAD_IDX,
+        )
 
     def get_loss(
         self,
@@ -240,7 +277,11 @@ def export_to_caffe2(
             predict_net, probability_out, model_out, output_name, self.target_names
         )
 
+    def torchscript_predictions(self):
+        return jit.script(CRFWordTaggingScores(self.target_names, jit.script(self.crf)))
+
 
+@jit.script
 def _rearrange_output(logit, pred):
     """
     Rearrange the word logits so that the decoded word has the highest valued
@@ -252,3 +293,47 @@ def _rearrange_output(logit, pred):
     logit_rearranged = logit.scatter(2, pred_indices, max_logits)
     logit_rearranged.scatter_(2, max_logit_indices, pred_logits)
     return logit_rearranged
+
+
+@jit.script
+def _get_prediction_from_scores(
+    scores: torch.Tensor, classes: List[str]
+) -> List[List[Dict[str, float]]]:
+    """
+    Given scores for a batch, get the prediction for each word in the form of a
+    List[List[Dict[str, float]]] for callers of the torchscript model to consume.
+    The outer list iterates over batches of sentences and the inner iterates
+    over each token in the sentence. The dictionary consists of
+    `label:score` for each word.
+
+    Example:
+
+    Assuming slot labels are [No-Label, Number, Name]
+    Utterances: [[call john please], [Brightness 25]]
+    Output could look like:
+    [
+        [
+            { No-Label: -0.1, Number: -1.5, Name: -9.01},
+            { No-Label: -2.1, Number: -1.5, Name: -0.01},
+            { No-Label: -0.1, Number: -1.5, Name: -2.01},
+        ],
+        [
+            { No-Label: -0.1, Number: -1.5, Name: -9.01},
+            { No-Label: -2.1, Number: -0.5, Name: -7.01},
+            { No-Label: -0.1, Number: -1.5, Name: -2.01},
+        ]
+    ]
+    """
+    results: List[List[Dict[str, float]]] = []
+    # Extra verbosity because jit doesn't support zip
+    for sentence_scores in scores.chunk(len(scores)):
+        sentence_scores = sentence_scores.squeeze(0)
+        sentence_response: List[Dict[str, float]] = []
+        for word_scores in sentence_scores.chunk(len(sentence_scores)):
+            word_scores = word_scores.squeeze(0)
+            word_response: Dict[str, float] = {}
+            for i in range(len(classes)):
+                word_response[classes[i]] = float(word_scores[i].item())
+            sentence_response.append(word_response)
+        results.append(sentence_response)
+    return results
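
For callers, the scripted scores modules ultimately yield plain Python structures. A small example of consuming the `List[List[Dict[str, float]]]` shape documented above (the scores here are made up):

# Pick the highest-scoring label for each token in each sentence
batch = [
    [
        {"No-Label": -0.1, "Number": -1.5, "Name": -9.01},
        {"No-Label": -2.1, "Number": -1.5, "Name": -0.01},
    ],
]
for sentence in batch:
    print([max(token, key=token.get) for token in sentence])
# prints ['No-Label', 'Name']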
