Skip to content
This repository was archived by the owner on Nov 22, 2022. It is now read-only.

Commit f32b9ea

Browse files
jingfeidu authored and facebook-github-bot committed
modify accuracy calculation for multi-label classification
Summary: Previously we didn't count TN as a correct prediction. Counting this to make it consistent with how we calculate losses. Reviewed By: chenyangyu1988. Differential Revision: D19798403. fbshipit-source-id: 7430b2250c8fb1a11877083f581c5bdc2f82b362
1 parent f7e1ff4 commit f32b9ea

File tree

1 file changed

+12
-16
lines changed

1 file changed

+12
-16
lines changed

pytext/metrics/__init__.py

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -914,24 +914,20 @@ def compute_multi_label_classification_metrics(
914914
num_expected_labels = 0
915915
per_label_confusions = PerLabelConfusions()
916916
for _, predicted, expected in predictions:
917-
# "predicted" is in the format of n_hot_encoding
918-
# Calculate TP & FN
919-
for true_label_idx in expected:
920-
if true_label_idx < 0:
921-
# padded label "-1"
922-
break
917+
for label_idx, label_name in enumerate(label_names):
923918
num_expected_labels += 1
924-
expected_label = label_names[true_label_idx]
925-
if predicted[true_label_idx] == 1:
926-
num_correct += 1
927-
per_label_confusions.update(expected_label, "TP", 1)
919+
# "predicted" is in the format of n_hot_encoding
920+
if predicted[label_idx] == 1:
921+
if label_idx in expected: # TP
922+
num_correct += 1
923+
per_label_confusions.update(label_name, "TP", 1)
924+
else: # FP
925+
per_label_confusions.update(label_name, "FP", 1)
928926
else:
929-
per_label_confusions.update(expected_label, "FN", 1)
930-
# Calculate FP
931-
for idx, pred in enumerate(predicted):
932-
if pred == 1 and idx not in expected:
933-
predicted_label = label_names[idx]
934-
per_label_confusions.update(predicted_label, "FP", 1)
927+
if label_idx in expected: # FN
928+
per_label_confusions.update(label_name, "FN", 1)
929+
else: # TN, update correct num
930+
num_correct += 1
935931

936932
accuracy = safe_division(num_correct, num_expected_labels)
937933
macro_prf1_metrics = per_label_confusions.compute_metrics()

0 commit comments

Comments (0)