This repository was archived by the owner on Nov 22, 2022. It is now read-only.

Add support for Torchscript export of IntentSlotOutputLayer and CRF #1146

Closed
46 changes: 25 additions & 21 deletions pytext/models/crf.py
@@ -1,11 +1,11 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+from typing import List

import torch
+import torch.jit as jit
import torch.nn as nn
-from caffe2.python.crf_predict import apply_crf
-from pytext.common.constants import Padding
-from pytext.utils.cuda import GetTensor
-from torch.autograd import Variable


class CRF(nn.Module):
@@ -17,7 +17,9 @@ class CRF(nn.Module):
num_tags: The number of tags
"""

-    def __init__(self, num_tags: int, ignore_index: int) -> None:
+    def __init__(
+        self, num_tags: int, ignore_index: int, default_label_pad_index: int
+    ) -> None:
if num_tags <= 0:
raise ValueError(f"Invalid number of tags: {num_tags}")
super().__init__()
@@ -29,6 +31,7 @@ def __init__(self, num_tags: int, ignore_index: int) -> None:
self.end_tag = num_tags + 1
self.reset_parameters()
self.ignore_index = ignore_index
self.default_label_pad_index = default_label_pad_index

def reset_parameters(self) -> None:
nn.init.uniform_(self.transitions, -0.1, 0.1)
@@ -42,8 +45,8 @@ def set_transitions(self, transitions: torch.Tensor = None):
self.transitions.data = transitions

def forward(
-        self, emissions: torch.FloatTensor, tags: torch.LongTensor, reduce: bool = True
-    ) -> Variable:
+        self, emissions: torch.Tensor, tags: torch.Tensor, reduce: bool = True
+    ) -> torch.Tensor:
"""
Compute log-likelihood of input.

@@ -62,9 +65,8 @@ def forward(
llh = numerator - denominator
return llh if not reduce else torch.mean(llh)

-    def decode(
-        self, emissions: torch.FloatTensor, seq_lens: torch.LongTensor
-    ) -> torch.Tensor:
+    @jit.export
+    def decode(self, emissions: torch.Tensor, seq_lens: torch.Tensor) -> torch.Tensor:
"""
Given a set of emission probabilities, return the predicted tags.

@@ -78,10 +80,7 @@ def decode(
return result

def _compute_joint_llh(
-        self,
-        emissions: torch.FloatTensor,
-        tags: torch.LongTensor,
-        mask: torch.FloatTensor,
+        self, emissions: torch.Tensor, tags: torch.Tensor, mask: torch.Tensor
) -> torch.Tensor:
seq_len = emissions.shape[1]

@@ -115,7 +114,7 @@ def _compute_joint_llh(
return llh.squeeze(1)

def _compute_log_partition_function(
-        self, emissions: torch.FloatTensor, mask: torch.FloatTensor
+        self, emissions: torch.Tensor, mask: torch.Tensor
) -> torch.Tensor:
seq_len = emissions.shape[1]

@@ -139,8 +138,9 @@ def _compute_log_partition_function(
return torch.logsumexp(log_prob.squeeze(1), 1)

def _viterbi_decode(
-        self, emissions: torch.FloatTensor, mask: torch.FloatTensor
+        self, emissions: torch.Tensor, mask: torch.Tensor
    ) -> torch.Tensor:
+        tensor_device = emissions.device
seq_len = emissions.shape[1]
mask = mask.to(torch.uint8)

@@ -153,10 +153,12 @@ def _viterbi_decode(
: self.start_tag, self.end_tag
].unsqueeze(0)

-        best_scores_list = []
+        best_scores_list: List[torch.Tensor] = []
+        # Needed for TorchScript, as an empty list is assumed to be a list of tensors
+        empty_data: List[int] = []
        # If the element has only one token, an empty tensor in best_paths
        # keeps torch.cat() from crashing
-        best_paths_list = [GetTensor(torch.Tensor().long())]
+        best_paths_list = [torch.tensor(empty_data, device=tensor_device).long()]
best_scores_list.append(end_scores.unsqueeze(1))

for idx in range(1, seq_len):
@@ -185,12 +187,14 @@

_, max_indices_from_scores = torch.max(best_scores, 2)

-        valid_index_tensor = GetTensor(torch.tensor(0)).long()
-        if self.ignore_index == Padding.DEFAULT_LABEL_PAD_IDX:
+        valid_index_tensor = torch.tensor(0, device=tensor_device).long()
+        if self.ignore_index == self.default_label_pad_index:
            # No label for padding, so use 0 index.
            padding_tensor = valid_index_tensor
        else:
-            padding_tensor = GetTensor(torch.tensor(self.ignore_index)).long()
+            padding_tensor = torch.tensor(
+                self.ignore_index, device=tensor_device
+            ).long()

# Label for the last position is always based on the index with max score
# For illegal timesteps, we set as ignore_index
@@ -256,7 +260,7 @@ def _make_mask_from_targets(self, targets):
def _make_mask_from_seq_lens(self, seq_lens):
seq_lens = seq_lens.view(-1, 1)
max_len = torch.max(seq_lens)
-        range_tensor = GetTensor(torch.arange(max_len)).unsqueeze(0)
+        range_tensor = torch.arange(max_len, device=seq_lens.device).unsqueeze(0)
range_tensor = range_tensor.expand(seq_lens.size(0), range_tensor.size(1))
mask = (range_tensor < seq_lens).float()
return mask
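For instance, with seq_lens = [3, 1] the rewritten helper builds its range tensor on the inputs' own device and yields one mask row per sequence (a worked example, not part of the diff):

    import torch

    seq_lens = torch.tensor([[3], [1]])          # after .view(-1, 1)
    range_tensor = torch.arange(3).unsqueeze(0)  # [[0, 1, 2]]
    mask = (range_tensor.expand(2, 3) < seq_lens).float()
    # tensor([[1., 1., 1.],
    #         [1., 0., 0.]])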
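Taken together, these crf.py changes remove the pieces TorchScript cannot compile: device placement now follows the input tensors instead of the GetTensor helper, the pad index is injected at construction rather than read from Padding, and decode is exposed with @jit.export. A minimal sketch of scripting the reworked module (the tag count, pad index, and shapes below are illustrative assumptions, not values from the diff):

    import torch
    import torch.jit as jit
    from pytext.models.crf import CRF

    num_tags = 5  # hypothetical tag-set size
    crf = CRF(num_tags=num_tags, ignore_index=0, default_label_pad_index=0)
    scripted_crf = jit.script(crf)  # possible now that decode is @jit.export

    emissions = torch.randn(2, 7, num_tags)  # (batch, seq_len, num_tags)
    seq_lens = torch.tensor([7, 4])
    tags = scripted_crf.decode(emissions, seq_lens)  # per-token tag indices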
6 changes: 3 additions & 3 deletions pytext/models/decoders/intent_slot_model_decoder.py
@@ -1,7 +1,7 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

-from typing import List, Optional
+from typing import List, Optional, Tuple

import torch
import torch.nn as nn
@@ -77,7 +77,7 @@ def __init__(

def forward(
self, x_d: torch.Tensor, x_w: torch.Tensor, dense: Optional[torch.Tensor] = None
-    ) -> List[torch.Tensor]:
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
if dense is not None:
logit_d = self.doc_decoder(torch.cat((x_d, dense), 1))
else:
@@ -95,7 +95,7 @@ def forward(
dense = dense.unsqueeze(1).repeat(1, word_input_shape[1], 1)
x_w = torch.cat((x_w, dense), 2)

-        return [logit_d, self.word_decoder(x_w)]
+        return logit_d, self.word_decoder(x_w)

def get_decoder(self) -> List[nn.Module]:
"""Returns the document and word decoder modules.
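The List to Tuple change in the decoder's return type is what lets TorchScript callers unpack the two logit tensors statically; a list return erases how many tensors come back. A toy stand-in illustrating the pattern (the TwoHeads module is an assumption for illustration, not PyText code):

    from typing import Tuple

    import torch
    import torch.jit as jit
    import torch.nn as nn

    class TwoHeads(nn.Module):
        # Mimics IntentSlotModelDecoder's two-headed forward.
        def __init__(self, in_dim: int, doc_dim: int, word_dim: int) -> None:
            super().__init__()
            self.doc = nn.Linear(in_dim, doc_dim)
            self.word = nn.Linear(in_dim, word_dim)

        def forward(
            self, x_d: torch.Tensor, x_w: torch.Tensor
        ) -> Tuple[torch.Tensor, torch.Tensor]:
            # A Tuple return scripts to exactly two tensors; a List would not.
            return self.doc(x_d), self.word(x_w)

    scripted = jit.script(TwoHeads(8, 3, 5))
    logit_d, logit_w = scripted(torch.randn(2, 8), torch.randn(2, 4, 8))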
32 changes: 32 additions & 0 deletions pytext/models/output_layers/intent_slot_output_layer.py
@@ -4,16 +4,43 @@
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from caffe2.python import core
from pytext.common.constants import DatasetFieldName
from pytext.data.utils import Vocabulary
from pytext.models.module import create_module
from torch import jit

from .doc_classification_output_layer import ClassificationOutputLayer
from .output_layer_base import OutputLayerBase
from .word_tagging_output_layer import CRFOutputLayer, WordTaggingOutputLayer


class IntentSlotScores(nn.Module):
def __init__(self, doc_scores: jit.ScriptModule, word_scores: jit.ScriptModule):
super().__init__()
self.doc_scores = doc_scores
self.word_scores = word_scores

def forward(
self,
logits: Tuple[torch.Tensor, torch.Tensor],
seq_lengths: torch.Tensor,
token_indices: Optional[torch.Tensor] = None,
) -> Tuple[List[Dict[str, float]], List[List[Dict[str, float]]]]:
d_logits, w_logits = logits
if token_indices is not None:
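# Keep only the logits at the given token positions, expanding the
# indices across the label dimension for torch.gather.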
w_logits = torch.gather(
w_logits,
1,
token_indices.unsqueeze(2).expand(-1, -1, w_logits.size(-1)),
)

d_results = self.doc_scores(d_logits)
w_results = self.word_scores(w_logits, seq_lengths)
return d_results, w_results


class IntentSlotOutputLayer(OutputLayerBase):
"""
Output layer for joint intent classification and slot-filling models.
@@ -161,3 +188,8 @@ def export_to_caffe2(
) + self.word_output.export_to_caffe2(
workspace, init_net, predict_net, model_out[1], word_out_name
)

def torchscript_predictions(self):
doc_scores = self.doc_output.torchscript_predictions()
word_scores = self.word_output.torchscript_predictions()
return jit.script(IntentSlotScores(doc_scores, word_scores))
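A rough sketch of driving the exported module (the output_layer value and all shapes are assumptions for illustration; the result types follow IntentSlotScores.forward above):

    import torch

    # Assume `output_layer` is a trained IntentSlotOutputLayer instance.
    scores_module = output_layer.torchscript_predictions()

    d_logits = torch.randn(2, 4)     # (batch, num_intents)
    w_logits = torch.randn(2, 7, 9)  # (batch, seq_len, num_slot_labels)
    seq_lengths = torch.tensor([7, 5])

    doc_results, word_results = scores_module((d_logits, w_logits), seq_lengths)
    # doc_results: List[Dict[str, float]], one intent -> score dict per utterance
    # word_results: List[List[Dict[str, float]]], one dict per token per utterance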
89 changes: 87 additions & 2 deletions pytext/models/output_layers/word_tagging_output_layer.py
@@ -4,13 +4,14 @@
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.jit as jit
import torch.nn as nn
import torch.nn.functional as F
from caffe2.python import core
from pytext.common import Padding
from pytext.config.component import create_loss
from pytext.config.serialize import MissingValueError
from pytext.data.utils import Vocabulary
from pytext.fields import FieldMeta
from pytext.loss import (
AUCPRHingeLoss,
BinaryCrossEntropyLoss,
@@ -26,6 +27,35 @@
from .utils import OutputLayerUtils


class WordTaggingScores(nn.Module):
classes: List[str]

def __init__(self, classes):
super().__init__()
self.classes = classes

def forward(
self, logits: torch.Tensor, seq_lengths: Optional[torch.Tensor] = None
) -> List[List[Dict[str, float]]]:
scores: torch.Tensor = F.log_softmax(logits, 2)
return _get_prediction_from_scores(scores, self.classes)


class CRFWordTaggingScores(WordTaggingScores):
def __init__(self, classes: List[str], crf):
super().__init__(classes)
self.crf = crf
self.crf.eval()

def forward(
self, logits: torch.Tensor, seq_lengths: torch.Tensor
) -> List[List[Dict[str, float]]]:
pred = self.crf.decode(logits, seq_lengths)
logits_rearranged = _rearrange_output(logits, pred)
scores: torch.Tensor = F.log_softmax(logits_rearranged, 2)
return _get_prediction_from_scores(scores, self.classes)


class WordTaggingOutputLayer(OutputLayerBase):
"""
Output layer for word tagging models. It supports `CrossEntropyLoss` per word.
@@ -138,6 +168,9 @@ def export_to_caffe2(
predict_net, probability_out, model_out, output_name, self.target_names
)

def torchscript_predictions(self):
return jit.script(WordTaggingScores(self.target_names))


class CRFOutputLayer(OutputLayerBase):
"""
@@ -160,7 +193,11 @@ def from_config(cls, config: OutputLayerBase.Config, labels: Vocabulary):

def __init__(self, num_tags, labels: Vocabulary, *args) -> None:
super().__init__(list(labels), *args)
-        self.crf = CRF(num_tags, labels.get_pad_index(Padding.DEFAULT_LABEL_PAD_IDX))
+        self.crf = CRF(
+            num_tags=num_tags,
+            ignore_index=labels.get_pad_index(Padding.DEFAULT_LABEL_PAD_IDX),
+            default_label_pad_index=Padding.DEFAULT_LABEL_PAD_IDX,
+        )

def get_loss(
self,
@@ -240,7 +277,11 @@ def export_to_caffe2(
predict_net, probability_out, model_out, output_name, self.target_names
)

def torchscript_predictions(self):
return jit.script(CRFWordTaggingScores(self.target_names, jit.script(self.crf)))


@jit.script
def _rearrange_output(logit, pred):
"""
Rearrange the word logits so that the decoded word has the highest valued
@@ -252,3 +293,47 @@ def _rearrange_output(logit, pred):
logit_rearranged = logit.scatter(2, pred_indices, max_logits)
logit_rearranged.scatter_(2, max_logit_indices, pred_logits)
return logit_rearranged
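In effect, each token's maximum logit is swapped with the logit at the CRF-decoded index, so a later softmax ranks the decoded tag first. A worked single-token example (values invented; the intermediate names mirror those in the function body):

    import torch

    logit = torch.tensor([[[0.2, 1.5, 0.7]]])  # (batch=1, seq=1, tags=3)
    pred = torch.tensor([[2]])                 # CRF-decoded tag index

    max_logits, max_logit_indices = torch.max(logit, 2, keepdim=True)
    pred_indices = pred.unsqueeze(2)
    pred_logits = torch.gather(logit, 2, pred_indices)
    logit_rearranged = logit.scatter(2, pred_indices, max_logits)
    logit_rearranged.scatter_(2, max_logit_indices, pred_logits)
    # tensor([[[0.2000, 0.7000, 1.5000]]]) -- tag 2 now holds the max logit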


@jit.script
def _get_prediction_from_scores(
scores: torch.Tensor, classes: List[str]
) -> List[List[Dict[str, float]]]:
"""
Given scores for a batch, get the prediction for each word in the form of a
List[List[Dict[str, float]]] for callers of the torchscript model to consume.
The outer list iterates over batches of sentences and the inner iterates
over each token in the sentence. The dictionary consists of
`label:score` for each word.

Example:

Assuming slot labels are [No-Label, Number, Name]
Utterances: [[call john please], [Brightness 25]]
Output could look like:
[
[
{ No-Label: -0.1, Number: -1.5, Name: -9.01},
{ No-Label: -2.1, Number: -1.5, Name: -0.01},
{ No-Label: -0.1, Number: -1.5, Name: -2.01},
],
[
{ No-Label: -0.1, Number: -1.5, Name: -9.01},
{ No-Label: -2.1, Number: -0.5, Name: -7.01},
{ No-Label: -0.1, Number: -1.5, Name: -2.01},
]
]
"""
results: List[List[Dict[str, float]]] = []
# Extra verbosity because jit doesn't support zip
for sentence_scores in scores.chunk(len(scores)):
sentence_scores = sentence_scores.squeeze(0)
sentence_response: List[Dict[str, float]] = []
for word_scores in sentence_scores.chunk(len(sentence_scores)):
word_scores = word_scores.squeeze(0)
word_response: Dict[str, float] = {}
for i in range(len(classes)):
word_response[classes[i]] = float(word_scores[i].item())
sentence_response.append(word_response)
results.append(sentence_response)
return results
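Since the scripted modules return plain Python containers, callers can post-process without torch. A small sketch (the function name is illustrative) that picks the top label per token from the structure documented above:

    from typing import Dict, List

    def best_labels(word_results: List[List[Dict[str, float]]]) -> List[List[str]]:
        # Keep the highest-scoring label for each token of each utterance.
        return [
            [max(token_scores, key=token_scores.get) for token_scores in sentence]
            for sentence in word_results
        ]

    # With the docstring's example scores, the first utterance yields
    # ["No-Label", "Name", "No-Label"].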