This repository was archived by the owner on Nov 22, 2022. It is now read-only.

Commit 7d07750

shivanipods authored and facebook-github-bot committed
MultiLabel-MultiClass Model for Joint Sequence Tagging (#1335)
Summary:
Pull Request resolved: #1335

We need to support multi-class as well as multi-label prediction for joint models in PyText. This diff implements:
1. a Joint Multi Label Decoder
2. a MultiLabelClassification Output Layer
3. loss computation for multi-label, multi-class scenarios
4. label weights per label and per class
5. softmax options for output layers
6. a custom Metric Reporter, Metric Class, and Output for flow

Reviewed By: seayoung1112

Differential Revision: D20210880

fbshipit-source-id: 62e4e43bd962a0bfd44f7f54c7c484ebfaf6037c
1 parent 189ca70 commit 7d07750

File tree

10 files changed: +370 −6 lines changed

pytext/data/tensorizers.py

Lines changed: 2 additions & 2 deletions
@@ -853,8 +853,8 @@ def numberize(self, row):
                 label_idx_list.append(self.pad_idx)
             else:
                 raise Exception(
-                    "Found none or empty value in the list,"
-                    + " while pad_missing is disabled"
+                    "Found none or empty value in the list, \
+                        while pad_missing is disabled"
                 )
         else:
             label_idx_list.append(self.vocab.lookup_all(label))

pytext/metric_reporters/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -15,6 +15,7 @@
 from .regression_metric_reporter import RegressionMetricReporter
 from .squad_metric_reporter import SquadMetricReporter
 from .word_tagging_metric_reporter import (
+    MultiLabelSequenceTaggingMetricReporter,
     NERMetricReporter,
     SequenceTaggingMetricReporter,
     WordTaggingMetricReporter,
@@ -26,6 +27,7 @@
     "MetricReporter",
     "ClassificationMetricReporter",
     "MultiLabelClassificationMetricReporter",
+    "MultiLabelSequenceTaggingMetricReporter",
     "RegressionMetricReporter",
     "IntentSlotMetricReporter",
     "LanguageModelMetricReporter",

pytext/metric_reporters/word_tagging_metric_reporter.py

Lines changed: 68 additions & 0 deletions
@@ -13,6 +13,7 @@
     LabelPrediction,
     PRF1Metrics,
     compute_classification_metrics,
+    compute_multi_label_multi_class_soft_metrics,
 )
 from pytext.metrics.intent_slot_metrics import (
     Node,
@@ -92,6 +93,73 @@ def get_model_select_metric(self, metrics):
         return metrics.micro_scores.f1


+class MultiLabelSequenceTaggingMetricReporter(MetricReporter):
+    def __init__(self, label_names, pad_idx, channels, label_vocabs=None):
+        super().__init__(channels)
+        self.label_names = label_names
+        self.pad_idx = pad_idx
+        self.label_vocabs = label_vocabs
+
+    @classmethod
+    def from_config(cls, config, tensorizers):
+        return MultiLabelSequenceTaggingMetricReporter(
+            channels=[ConsoleChannel(), FileChannel((Stage.TEST,), config.output_path)],
+            label_names=tensorizers.keys(),
+            pad_idx=[v.pad_idx for _, v in tensorizers.items()],
+            label_vocabs=[v.vocab._vocab for _, v in tensorizers.items()],
+        )
+
+    def calculate_metric(self):
+        if len(self.all_scores) == 0:
+            return {}
+        list_score_pred_expect = []
+        for label_idx in range(0, len(self.label_names)):
+            list_score_pred_expect.append(
+                list(
+                    itertools.chain.from_iterable(
+                        (
+                            LabelPrediction(s, p, e)
+                            for s, p, e in zip(scores, pred, expect)
+                            if e != self.pad_idx[label_idx]
+                        )
+                        for scores, pred, expect in zip(
+                            self.all_scores[label_idx],
+                            self.all_preds[label_idx],
+                            self.all_targets[label_idx],
+                        )
+                    )
+                )
+            )
+        metrics = compute_multi_label_multi_class_soft_metrics(
+            list_score_pred_expect,
+            self.label_names,
+            self.label_vocabs,
+            self.calculate_loss(),
+        )
+        return metrics
+
+    def batch_context(self, raw_batch, batch):
+        return {}
+
+    @staticmethod
+    def get_model_select_metric(metrics):
+        if isinstance(metrics, dict):
+            # There are multiclass precision/recall labels
+            # Compute average precision
+            avg_precision = 0.0
+            for _, metric in metrics.items():
+                if metric:
+                    avg_precision += sum(
+                        v.average_precision
+                        for k, v in metric.items()
+                        if v.average_precision > 0
+                    ) / (len(metric.keys()) * 1.0)
+            avg_precision = avg_precision / (len(metrics.keys()) * 1.0)
+        else:
+            avg_precision = metrics.accuracy
+
+
 class SequenceTaggingMetricReporter(MetricReporter):
     def __init__(self, label_names, pad_idx, channels):
         super().__init__(channels)
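
For orientation, a minimal sketch of how model selection aggregates these metrics: the nested dict shape (label set → per-class soft metrics carrying an average_precision field) follows calculate_metric() above, but the label sets and numbers here are hypothetical:

from types import SimpleNamespace

from pytext.metric_reporters import MultiLabelSequenceTaggingMetricReporter

# Hypothetical metrics dict: one entry per label set, each mapping a class
# name to its soft metrics (only average_precision matters here).
toy_metrics = {
    "user_satisfaction": {
        "POS": SimpleNamespace(average_precision=0.8),
        "NEG": SimpleNamespace(average_precision=0.6),
    },
    "error_attribution": {
        "ASR": SimpleNamespace(average_precision=0.5),
        "NLU": SimpleNamespace(average_precision=0.7),
    },
}

# Mean average precision per label set, then the mean across label sets:
# (0.8 + 0.6) / 2 = 0.7 and (0.5 + 0.7) / 2 = 0.6, giving 0.65 overall.
print(MultiLabelSequenceTaggingMetricReporter.get_model_select_metric(toy_metrics))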

pytext/metrics/__init__.py

Lines changed: 31 additions & 0 deletions
@@ -753,6 +753,37 @@ def compute_multi_label_soft_metrics(
     return soft_metrics


+def compute_multi_label_multi_class_soft_metrics(
+    predictions: Sequence[Sequence[LabelListPrediction]],
+    label_names: Sequence[str],
+    label_vocabs: Sequence[Sequence[str]],
+    recall_at_precision_thresholds: Sequence[float] = RECALL_AT_PRECISION_THRESHOLDS,
+    precision_at_recall_thresholds: Sequence[float] = PRECISION_AT_RECALL_THRESHOLDS,
+) -> Dict[int, SoftClassificationMetrics]:
+    """
+    Computes multi-label soft classification metrics with multi-class accommodation.
+
+    Args:
+        predictions: multi-label predictions,
+            including the confidence score for each label.
+        label_names: Indexed label names.
+        recall_at_precision_thresholds: precision thresholds at which to calculate
+            recall.
+        precision_at_recall_thresholds: recall thresholds at which to calculate
+            precision.
+
+    Returns:
+        Dict from label strings to their corresponding soft metrics.
+    """
+    soft_metrics = {}
+    for label_idx, label_vocab in enumerate(label_vocabs):
+        label = list(label_names)[label_idx]
+        soft_metrics[label] = compute_soft_metrics(predictions[label_idx], label_vocab)
+    return soft_metrics
+
+
 def compute_matthews_correlation_coefficients(
     TP: int, FP: int, FN: int, TN: int
 ) -> float:
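
A minimal call-shape sketch for the new function, assuming LabelPrediction(label_scores, predicted_label, expected_label) entries as assembled by MultiLabelSequenceTaggingMetricReporter above; the label set, vocabulary, and scores are hypothetical (scores shown as probabilities for readability):

from pytext.metrics import LabelPrediction, compute_multi_label_multi_class_soft_metrics

# One outer entry per label set; each inner list holds the flattened
# per-token predictions for that label set.
predictions = [
    [
        LabelPrediction(label_scores=[0.2, 0.8], predicted_label=1, expected_label=1),
        LabelPrediction(label_scores=[0.7, 0.3], predicted_label=0, expected_label=1),
    ]
]

soft = compute_multi_label_multi_class_soft_metrics(
    predictions,
    label_names=["sentiment"],
    label_vocabs=[["NEG", "POS"]],
)
# soft["sentiment"] maps each class in the vocab to its SoftClassificationMetrics.
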
pytext/models/decoders/multi_label_decoder.py (path not captured in this page; inferred from the class name and its relative import of .decoder_base)

Lines changed: 64 additions & 0 deletions
@@ -0,0 +1,64 @@
+#!/usr/bin/env python3
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+
+from typing import Dict, List
+
+import torch
+import torch.nn as nn
+from pytext.utils.usage import log_class_usage
+
+from .decoder_base import DecoderBase
+
+
+class MultiLabelDecoder(DecoderBase):
+    """
+    Implements an 'n-tower' MLP: one tower per label set.
+    Used in USM/EA: user satisfaction modeling, pTSR prediction, and
+    error attribution are the three label sets that need predicting.
+
+    """
+
+    class Config(DecoderBase.Config):
+        # Intermediate hidden dimensions
+        hidden_dims: List[int] = []
+
+    def __init__(
+        self,
+        config: Config,
+        in_dim: int,
+        output_dim: Dict[str, int],
+        label_names: List[str],
+    ) -> None:
+        super().__init__(config)
+        self.label_mlps = nn.ModuleDict({})
+        # Store the ordered list to preserve the ordering of the labels
+        # when generating the output layer
+        self.label_names = label_names
+        aggregate_out_dim = 0
+        for label_, _ in output_dim.items():
+            self.label_mlps[label_] = MultiLabelDecoder.get_mlp(
+                in_dim, output_dim[label_], config.hidden_dims
+            )
+            aggregate_out_dim += output_dim[label_]
+        self.out_dim = (1, aggregate_out_dim)
+        log_class_usage(__class__)
+
+    @staticmethod
+    def get_mlp(in_dim: int, out_dim: int, hidden_dims: List[int]):
+        layers = []
+        current_dim = in_dim
+        for dim in hidden_dims or []:
+            layers.append(nn.Linear(current_dim, dim))
+            layers.append(nn.ReLU())
+            current_dim = dim
+        layers.append(nn.Linear(current_dim, out_dim))
+        return nn.Sequential(*layers)
+
+    def forward(self, *input: torch.Tensor):
+        logits = tuple(
+            self.label_mlps[x](torch.cat(input, 1)) for x in self.label_names
+        )
+        return logits
+
+    def get_decoder(self) -> List[nn.Module]:
+        return self.label_mlps
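
A minimal usage sketch of the new decoder; the module path follows the (inferred) header above, and the 16-dimensional input, batch size, and two label sets are hypothetical:

import torch
from pytext.models.decoders.multi_label_decoder import MultiLabelDecoder

decoder = MultiLabelDecoder(
    config=MultiLabelDecoder.Config(hidden_dims=[32]),
    in_dim=16,
    output_dim={"intent": 4, "satisfaction": 3},
    label_names=["intent", "satisfaction"],
)

rep = torch.rand(8, 16)  # a batch of 8 encoder representations
intent_logits, satisfaction_logits = decoder(rep)
print(intent_logits.shape, satisfaction_logits.shape)
# torch.Size([8, 4]) torch.Size([8, 3])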

pytext/models/output_layers/doc_classification_output_layer.py

Lines changed: 2 additions & 2 deletions
@@ -128,7 +128,7 @@ def forward(self, logits: torch.Tensor):
 class BinaryClassificationOutputLayer(ClassificationOutputLayer):
     def get_pred(self, logit, *args, **kwargs):
         """See `OutputLayerBase.get_pred()`."""
-        preds = torch.max(logit, 1)[1]
+        preds = torch.max(logit, -1)[1]
         scores = F.logsigmoid(logit)
         return preds, scores

@@ -153,7 +153,7 @@ def export_to_caffe2(
 class MulticlassOutputLayer(ClassificationOutputLayer):
     def get_pred(self, logit, *args, **kwargs):
         """See `OutputLayerBase.get_pred()`."""
-        preds = torch.max(logit, 1)[1]
+        preds = torch.max(logit, -1)[1]
         scores = F.log_softmax(logit, 1)
         return preds, scores
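
Why the dim change matters, sketched below: for the 3-D [batch, seq_len, num_classes] logits that joint sequence tagging produces, torch.max(logit, 1) reduces over the sequence axis rather than the classes, while dim=-1 picks the argmax class per token and still matches dim=1 in the 2-D case (shapes here are illustrative):

import torch

logit = torch.rand(2, 5, 3)  # [batch, seq_len, num_classes]
print(torch.max(logit, 1)[1].shape)   # torch.Size([2, 3]): max taken over the sequence axis
print(torch.max(logit, -1)[1].shape)  # torch.Size([2, 5]): predicted class per token

logit_2d = torch.rand(2, 3)  # [batch, num_classes]
assert torch.equal(torch.max(logit_2d, 1)[1], torch.max(logit_2d, -1)[1])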
