Skip to content

Commit bd36919

Browse files
authored
📏 Add PRO metric (#508)
* remove cpu warning * add pro metric test * remove pylint ignore statements * fix component labeling bug * add more tests for pro metric and ccomp labeling * use kornia for ccomp labeling * fix aupro tests * address PR comments
1 parent a03e592 commit bd36919

File tree

5 files changed

+207
-13
lines changed

5 files changed

+207
-13
lines changed

anomalib/utils/metrics/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,9 @@
1818
from .collection import AnomalibMetricCollection
1919
from .min_max import MinMax
2020
from .optimal_f1 import OptimalF1
21+
from .pro import PRO
2122

22-
__all__ = ["AUROC", "AUPR", "AUPRO", "OptimalF1", "AdaptiveThreshold", "AnomalyScoreDistribution", "MinMax"]
23+
__all__ = ["AUROC", "AUPR", "AUPRO", "OptimalF1", "AdaptiveThreshold", "AnomalyScoreDistribution", "MinMax", "PRO"]
2324

2425

2526
def get_metrics(config: Union[ListConfig, DictConfig]) -> Tuple[AnomalibMetricCollection, AnomalibMetricCollection]:

anomalib/utils/metrics/aupro.py

+9-4
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,17 @@
66
from typing import Any, Callable, List, Optional, Tuple
77

88
import torch
9-
from kornia.contrib import connected_components
109
from matplotlib.figure import Figure
1110
from torch import Tensor
1211
from torchmetrics import Metric
1312
from torchmetrics.functional import auc, roc
1413
from torchmetrics.utilities.data import dim_zero_cat
1514

15+
from anomalib.utils.metrics.pro import (
16+
connected_components_cpu,
17+
connected_components_gpu,
18+
)
19+
1620
from .plotting_utils import plot_figure
1721

1822

@@ -80,9 +84,10 @@ def _compute(self) -> Tuple[Tensor, Tensor]:
8084
)
8185
target = target.unsqueeze(1) # kornia expects N1HW format
8286
target = target.type(torch.float) # kornia expects FloatTensor
83-
cca = connected_components(
84-
target, num_iterations=1000
85-
) # Need higher thresholds this to avoid oversegmentation.
87+
if target.is_cuda:
88+
cca = connected_components_gpu(target)
89+
else:
90+
cca = connected_components_cpu(target)
8691

8792
preds = preds.flatten()
8893
cca = cca.flatten()

anomalib/utils/metrics/pro.py

+112
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
"""Implementation of PRO metric based on TorchMetrics."""
2+
3+
# Copyright (C) 2022 Intel Corporation
4+
# SPDX-License-Identifier: Apache-2.0
5+
6+
from typing import List
7+
8+
import cv2
9+
import numpy as np
10+
import torch
11+
from kornia.contrib import connected_components
12+
from torch import Tensor
13+
from torchmetrics import Metric
14+
from torchmetrics.functional import recall
15+
from torchmetrics.utilities.data import dim_zero_cat
16+
17+
18+
class PRO(Metric):
19+
"""Per-Region Overlap (PRO) Score."""
20+
21+
target: List[Tensor]
22+
preds: List[Tensor]
23+
24+
def __init__(self, threshold: float = 0.5, **kwargs) -> None:
25+
super().__init__(**kwargs)
26+
self.threshold = threshold
27+
28+
self.add_state("preds", default=[], dist_reduce_fx="cat")
29+
self.add_state("target", default=[], dist_reduce_fx="cat")
30+
31+
def update(self, predictions: Tensor, targets: Tensor) -> None:
32+
"""Compute the PRO score for the current batch."""
33+
34+
self.target.append(targets)
35+
self.preds.append(predictions)
36+
37+
def compute(self) -> Tensor:
38+
"""Compute the macro average of the PRO score across all regions in all batches."""
39+
target = dim_zero_cat(self.target)
40+
preds = dim_zero_cat(self.preds)
41+
42+
if target.is_cuda:
43+
comps = connected_components_gpu(target.unsqueeze(1))
44+
else:
45+
comps = connected_components_cpu(target.unsqueeze(1))
46+
pro = pro_score(preds, comps, threshold=self.threshold)
47+
return pro
48+
49+
50+
def pro_score(predictions: Tensor, comps: Tensor, threshold: float = 0.5) -> Tensor:
51+
"""Calculate the PRO score for a batch of predictions.
52+
53+
Args:
54+
predictions (Tensor): Predicted anomaly masks (Bx1xHxW)
55+
comps: (Tensor): Labeled connected components (BxHxW). The components should be labeled from 0 to N
56+
threshold (float): When predictions are passed as float, the threshold is used to binarize the predictions.
57+
58+
Returns:
59+
Tensor: Scalar value representing the average PRO score for the input batch.
60+
"""
61+
if predictions.dtype == torch.float:
62+
predictions = predictions > threshold
63+
64+
n_comps = len(comps.unique())
65+
66+
preds = comps.clone()
67+
preds[~predictions] = 0
68+
if n_comps == 1: # only background
69+
return torch.Tensor([1.0])
70+
pro = recall(preds.flatten(), comps.flatten(), num_classes=n_comps, average="macro", ignore_index=0)
71+
return pro
72+
73+
74+
def connected_components_gpu(binary_input: Tensor, num_iterations: int = 1000) -> Tensor:
75+
"""Perform connected component labeling on GPU and remap the labels from 0 to N.
76+
77+
Args:
78+
binary_input (Tensor): Binary input data from which we want to extract connected components (Bx1xHxW)
79+
num_iterations (int): Number of iterations used in the connected component computation.
80+
81+
Returns:
82+
Tensor: Components labeled from 0 to N.
83+
"""
84+
components = connected_components(binary_input, num_iterations=num_iterations)
85+
86+
# remap component values from 0 to N
87+
labels = components.unique()
88+
for new_label, old_label in enumerate(labels):
89+
components[components == old_label] = new_label
90+
91+
return components.int()
92+
93+
94+
def connected_components_cpu(image: Tensor) -> Tensor:
95+
"""Connected component labeling on CPU.
96+
97+
Args:
98+
image (Tensor): Binary input data from which we want to extract connected components (Bx1xHxW)
99+
100+
Returns:
101+
Tensor: Components labeled from 0 to N.
102+
"""
103+
components = torch.zeros_like(image)
104+
label_idx = 1
105+
for i, mask in enumerate(image):
106+
mask = mask.squeeze().numpy().astype(np.uint8)
107+
_, comps = cv2.connectedComponents(mask)
108+
# remap component values to make sure every component has a unique value when outputs are concatenated
109+
for label in np.unique(comps)[1:]:
110+
components[i, 0, ...][np.where(comps == label)] = label_idx
111+
label_idx += 1
112+
return components.int()

tests/pre_merge/utils/metrics/test_aupro.py

+4-8
Original file line numberDiff line numberDiff line change
@@ -13,21 +13,17 @@ def pytest_generate_tests(metafunc):
1313
torch.tensor(
1414
[
1515
[
16-
[
17-
[0, 0, 0, 1, 0, 0, 0],
18-
]
19-
* 400,
16+
[0, 0, 0, 1, 0, 0, 0],
2017
]
18+
* 400,
2119
]
2220
),
2321
torch.tensor(
2422
[
2523
[
26-
[
27-
[0, 1, 0, 1, 0, 1, 0],
28-
]
29-
* 400,
24+
[0, 1, 0, 1, 0, 1, 0],
3025
]
26+
* 400,
3127
]
3228
),
3329
]
+80
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
import torch
2+
from torch import Tensor
3+
from torchvision.transforms import RandomAffine
4+
5+
from anomalib.data.utils import random_2d_perlin
6+
from anomalib.utils.metrics.pro import (
7+
PRO,
8+
connected_components_cpu,
9+
connected_components_gpu,
10+
)
11+
12+
13+
def test_pro():
14+
"""Checks if PRO metric computes the (macro) average of the per-region overlap."""
15+
16+
labels = Tensor(
17+
[
18+
[
19+
[0, 0, 0, 0, 0],
20+
[1, 0, 0, 0, 0],
21+
[0, 0, 0, 0, 0],
22+
[1, 1, 0, 0, 0],
23+
[0, 0, 0, 0, 0],
24+
[1, 1, 1, 0, 0],
25+
[0, 0, 0, 0, 0],
26+
[1, 1, 1, 1, 0],
27+
[0, 0, 0, 0, 0],
28+
[1, 1, 1, 1, 1],
29+
]
30+
]
31+
)
32+
33+
preds = (torch.arange(10) / 10) + 0.05
34+
preds = preds.unsqueeze(1).repeat(1, 5).view(1, 1, 10, 5)
35+
36+
thresholds = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]
37+
targets = [1.0, 0.8, 0.6, 0.4, 0.2, 0.0]
38+
for threshold, target in zip(thresholds, targets):
39+
pro = PRO(threshold=threshold)
40+
pro.update(preds, labels)
41+
assert pro.compute() == target
42+
43+
44+
def test_device_consistency():
45+
"""Test if the pro metric yields the same results between cpu and gpu."""
46+
47+
transform = RandomAffine(5, None, (0.95, 1.05), 5)
48+
49+
batch = torch.zeros((32, 256, 256))
50+
for i in range(batch.shape[0]):
51+
batch[i, ...] = random_2d_perlin((256, 256), (torch.tensor(4), torch.tensor(4))) > 0.5
52+
53+
preds = transform(batch).unsqueeze(1)
54+
55+
pro_cpu = PRO()
56+
pro_gpu = PRO()
57+
58+
pro_cpu.update(preds.cpu(), batch.cpu())
59+
pro_gpu.update(preds.cuda(), batch.cuda())
60+
61+
assert torch.isclose(pro_cpu.compute(), pro_gpu.compute().cpu())
62+
63+
64+
def test_connected_component_labeling():
65+
"""Tests if the connected component labeling algorithms on cpu and gpu yield the same result."""
66+
67+
# generate batch of random binary images using perlin noise
68+
batch = torch.zeros((32, 1, 256, 256))
69+
for i in range(batch.shape[0]):
70+
batch[i, ...] = random_2d_perlin((256, 256), (torch.tensor(4), torch.tensor(4))) > 0.5
71+
72+
# get connected component results on both cpu and gpu
73+
cc_cpu = connected_components_cpu(batch.cpu())
74+
cc_gpu = connected_components_gpu(batch.cuda())
75+
76+
# check if comps are ordered from 0 to N
77+
assert len(cc_cpu.unique()) == cc_cpu.unique().max() + 1
78+
assert len(cc_gpu.unique()) == cc_gpu.unique().max() + 1
79+
# check if same number of comps found between cpu and gpu
80+
assert len(cc_cpu.unique()) == len(cc_gpu.unique())

0 commit comments

Comments
 (0)