This repository was archived by the owner on Nov 22, 2022. It is now read-only.

MultiLabel-MultiClass Model for Joint Sequence Tagging #1335

Closed
4 changes: 2 additions & 2 deletions pytext/data/tensorizers.py
@@ -853,8 +853,8 @@ def numberize(self, row):
                 label_idx_list.append(self.pad_idx)
             else:
                 raise Exception(
-                    "Found none or empty value in the list,"
-                    + " while pad_missing is disabled"
+                    "Found none or empty value in the list, \
+                    while pad_missing is disabled"
                 )
         else:
             label_idx_list.append(self.vocab.lookup_all(label))
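A side note on the rewritten message, as a standalone sketch (not part of the diff): implicit concatenation joins the two halves exactly as written, while a backslash continuation inside a string literal keeps the next physical line's leading indentation inside the string.

# Standalone sketch: the two styles of splitting a long message differ.
concat = (
    "Found none or empty value in the list,"
    + " while pad_missing is disabled"
)
continued = "Found none or empty value in the list, \
    while pad_missing is disabled"

print(concat)     # single space between the halves
print(continued)  # the second line's leading spaces stay in the string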
2 changes: 2 additions & 0 deletions pytext/metric_reporters/__init__.py
@@ -15,6 +15,7 @@
 from .regression_metric_reporter import RegressionMetricReporter
 from .squad_metric_reporter import SquadMetricReporter
 from .word_tagging_metric_reporter import (
+    MultiLabelSequenceTaggingMetricReporter,
     NERMetricReporter,
     SequenceTaggingMetricReporter,
     WordTaggingMetricReporter,
@@ -26,6 +27,7 @@
     "MetricReporter",
     "ClassificationMetricReporter",
     "MultiLabelClassificationMetricReporter",
+    "MultiLabelSequenceTaggingMetricReporter",
     "RegressionMetricReporter",
     "IntentSlotMetricReporter",
     "LanguageModelMetricReporter",
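With the export in place, downstream code can pull the reporter from the package root; a minimal import sketch:

from pytext.metric_reporters import MultiLabelSequenceTaggingMetricReporter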
68 changes: 68 additions & 0 deletions pytext/metric_reporters/word_tagging_metric_reporter.py
@@ -13,6 +13,7 @@
     LabelPrediction,
     PRF1Metrics,
     compute_classification_metrics,
+    compute_multi_label_multi_class_soft_metrics,
 )
 from pytext.metrics.intent_slot_metrics import (
     Node,
@@ -92,6 +93,73 @@ def get_model_select_metric(self, metrics):
return metrics.micro_scores.f1


class MultiLabelSequenceTaggingMetricReporter(MetricReporter):
def __init__(self, label_names, pad_idx, channels, label_vocabs=None):
super().__init__(channels)
self.label_names = label_names
self.pad_idx = pad_idx
self.label_vocabs = label_vocabs

@classmethod
def from_config(cls, config, tensorizers):
return MultiLabelSequenceTaggingMetricReporter(
channels=[ConsoleChannel(), FileChannel((Stage.TEST,), config.output_path)],
label_names=tensorizers.keys(),
            pad_idx=[v.pad_idx for v in tensorizers.values()],
            label_vocabs=[v.vocab._vocab for v in tensorizers.values()],
)

    def calculate_metric(self):
        if len(self.all_scores) == 0:
            return {}
        # For each label set, flatten the per-batch (scores, preds, targets)
        # triples into a single list of LabelPrediction, dropping positions
        # whose target is that label set's padding index.
        list_score_pred_expect = []
        for label_idx in range(len(self.label_names)):
            list_score_pred_expect.append(
                list(
                    itertools.chain.from_iterable(
                        (
                            LabelPrediction(s, p, e)
                            for s, p, e in zip(scores, pred, expect)
                            if e != self.pad_idx[label_idx]
                        )
                        for scores, pred, expect in zip(
                            self.all_scores[label_idx],
                            self.all_preds[label_idx],
                            self.all_targets[label_idx],
                        )
                    )
                )
            )
        metrics = compute_multi_label_multi_class_soft_metrics(
            list_score_pred_expect,
            self.label_names,
            self.label_vocabs,
        )
        return metrics

def batch_context(self, raw_batch, batch):
return {}

    @staticmethod
    def get_model_select_metric(metrics):
        if isinstance(metrics, dict):
            # Multi-label case: metrics maps each label set to its per-class
            # soft metrics. Average the positive average precisions within
            # each label set, then average across the label sets.
            avg_precision = 0.0
            for _, metric in metrics.items():
                if metric:
                    avg_precision += sum(
                        v.average_precision
                        for v in metric.values()
                        if v.average_precision > 0
                    ) / len(metric)
            avg_precision = avg_precision / len(metrics)
        else:
            avg_precision = metrics.accuracy
        return avg_precision


class SequenceTaggingMetricReporter(MetricReporter):
def __init__(self, label_names, pad_idx, channels):
super().__init__(channels)
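To make the model-selection metric concrete, a toy walk-through with hypothetical numbers and a stand-in for SoftClassificationMetrics: per-class average precision is averaged within each label set, then across label sets.

from collections import namedtuple

# Stand-in carrying only the field get_model_select_metric reads.
AP = namedtuple("AP", "average_precision")

# Hypothetical nested metrics: {label set -> {class name -> metrics}}.
metrics = {
    "satisfaction": {"pos": AP(0.8), "neg": AP(0.6)},        # per-set mean 0.7
    "error_attribution": {"asr": AP(0.9), "nlu": AP(0.7)},   # per-set mean 0.8
}

# Mirrors the reporter's averaging for all-positive APs:
select_metric = sum(
    sum(v.average_precision for v in m.values()) / len(m)
    for m in metrics.values()
) / len(metrics)
assert abs(select_metric - 0.75) < 1e-9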
31 changes: 31 additions & 0 deletions pytext/metrics/__init__.py
@@ -753,6 +753,37 @@ def compute_multi_label_soft_metrics(
return soft_metrics


def compute_multi_label_multi_class_soft_metrics(
    predictions: Sequence[Sequence[LabelPrediction]],
    label_names: Sequence[str],
    label_vocabs: Sequence[Sequence[str]],
    recall_at_precision_thresholds: Sequence[float] = RECALL_AT_PRECISION_THRESHOLDS,
    precision_at_recall_thresholds: Sequence[float] = PRECISION_AT_RECALL_THRESHOLDS,
) -> Dict[str, Dict[str, SoftClassificationMetrics]]:
"""

Computes multi-label soft classification metrics with multi-class accommodation

Args:
predictions: multi-label predictions,
including the confidence score for each label.
label_names: Indexed label names.
recall_at_precision_thresholds: precision thresholds at which to calculate
recall
precision_at_recall_thresholds: recall thresholds at which to calculate
precision


Returns:
Dict from label strings to their corresponding soft metrics.
"""
    soft_metrics = {}
    for label, preds, vocab in zip(label_names, predictions, label_vocabs):
        soft_metrics[label] = compute_soft_metrics(preds, vocab)
    return soft_metrics


def compute_matthews_correlation_coefficients(
TP: int, FP: int, FN: int, TN: int
) -> float:
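A minimal call sketch for the new helper, with hypothetical label sets and scores. LabelPrediction instances are built exactly as the reporter builds them, carrying per-class scores plus predicted and expected indices:

# Hypothetical: two label sets, one scored example each.
predictions = [
    [LabelPrediction(label_scores=[-0.1, -2.3], predicted_label=0, expected_label=0)],
    [LabelPrediction(label_scores=[-1.2, -0.4, -2.0], predicted_label=1, expected_label=2)],
]
metrics = compute_multi_label_multi_class_soft_metrics(
    predictions,
    label_names=["satisfaction", "error_attribution"],
    label_vocabs=[["pos", "neg"], ["asr", "nlu", "other"]],
)
# metrics["satisfaction"] -> per-class soft metrics for that label set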
64 changes: 64 additions & 0 deletions pytext/models/decoders/multilabel_decoder.py
@@ -0,0 +1,64 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

from typing import Dict, List

import torch
import torch.nn as nn
from pytext.utils.usage import log_class_usage

from .decoder_base import DecoderBase


class MultiLabelDecoder(DecoderBase):
"""
Implements a 'n-tower' MLP: one for each of the multi labels
Used in USM/EA: the user satisfaction modeling, pTSR prediction and
Error Attribution are all 3 label sets that need predicting.

"""

class Config(DecoderBase.Config):
# Intermediate hidden dimensions
hidden_dims: List[int] = []

def __init__(
self,
config: Config,
in_dim: int,
output_dim: Dict[str, int],
label_names: List[str],
) -> None:
super().__init__(config)
self.label_mlps = nn.ModuleDict({})
# Store the ordered list to preserve the ordering of the labels
# when generating the output layer
self.label_names = label_names
aggregate_out_dim = 0
        for label_, dim_ in output_dim.items():
            self.label_mlps[label_] = MultiLabelDecoder.get_mlp(
                in_dim, dim_, config.hidden_dims
            )
            aggregate_out_dim += dim_
self.out_dim = (1, aggregate_out_dim)
log_class_usage(__class__)

@staticmethod
def get_mlp(in_dim: int, out_dim: int, hidden_dims: List[int]):
layers = []
current_dim = in_dim
for dim in hidden_dims or []:
layers.append(nn.Linear(current_dim, dim))
layers.append(nn.ReLU())
current_dim = dim
layers.append(nn.Linear(current_dim, out_dim))
return nn.Sequential(*layers)

def forward(self, *input: torch.Tensor):
logits = tuple(
self.label_mlps[x](torch.cat(input, 1)) for x in self.label_names
)
return logits

    def get_decoder(self) -> nn.ModuleDict:
return self.label_mlps
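A standalone usage sketch for the decoder; dimensions and label-set names are hypothetical, and Config is assumed to take ConfigBase-style keyword arguments. Three towers decode a shared 256-dim representation into per-label-set logits:

import torch

decoder = MultiLabelDecoder(
    MultiLabelDecoder.Config(hidden_dims=[128]),
    in_dim=256,
    output_dim={"satisfaction": 3, "ptsr": 2, "error_attribution": 5},
    label_names=["satisfaction", "ptsr", "error_attribution"],
)
representation = torch.rand(8, 256)  # batch of 8 pooled encodings
logits = decoder(representation)     # tuple of three tensors, one per tower
assert [t.shape[-1] for t in logits] == [3, 2, 5]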
4 changes: 2 additions & 2 deletions pytext/models/output_layers/classification_output_layer.py
@@ -128,7 +128,7 @@ def forward(self, logits: torch.Tensor):
 class BinaryClassificationOutputLayer(ClassificationOutputLayer):
     def get_pred(self, logit, *args, **kwargs):
         """See `OutputLayerBase.get_pred()`."""
-        preds = torch.max(logit, 1)[1]
+        preds = torch.max(logit, -1)[1]
         scores = F.logsigmoid(logit)
         return preds, scores

@@ -153,7 +153,7 @@ def export_to_caffe2(
 class MulticlassOutputLayer(ClassificationOutputLayer):
     def get_pred(self, logit, *args, **kwargs):
         """See `OutputLayerBase.get_pred()`."""
-        preds = torch.max(logit, 1)[1]
+        preds = torch.max(logit, -1)[1]
         scores = F.log_softmax(logit, 1)
         return preds, scores

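The dim change matters once these output layers see sequence-tagging logits of shape [batch, seq_len, num_classes] rather than [batch, num_classes]; a quick standalone check of the argmax dimension:

import torch

logits = torch.rand(4, 7, 5)     # [batch, seq_len, num_classes]
old = torch.max(logits, 1)[1]    # argmax over seq_len: shape [4, 5], wrong axis
new = torch.max(logits, -1)[1]   # argmax over classes: shape [4, 7]
assert new.shape == (4, 7)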