|
| 1 | +"""Implementation of PRO metric based on TorchMetrics.""" |
| 2 | + |
| 3 | +# Copyright (C) 2022 Intel Corporation |
| 4 | +# SPDX-License-Identifier: Apache-2.0 |
| 5 | + |
| 6 | +from typing import List |
| 7 | + |
| 8 | +import cv2 |
| 9 | +import numpy as np |
| 10 | +import torch |
| 11 | +from kornia.contrib import connected_components |
| 12 | +from torch import Tensor |
| 13 | +from torchmetrics import Metric |
| 14 | +from torchmetrics.functional import recall |
| 15 | +from torchmetrics.utilities.data import dim_zero_cat |
| 16 | + |
| 17 | + |
| 18 | +class PRO(Metric): |
| 19 | + """Per-Region Overlap (PRO) Score.""" |
| 20 | + |
| 21 | + target: List[Tensor] |
| 22 | + preds: List[Tensor] |
| 23 | + |
| 24 | + def __init__(self, threshold: float = 0.5, **kwargs) -> None: |
| 25 | + super().__init__(**kwargs) |
| 26 | + self.threshold = threshold |
| 27 | + |
| 28 | + self.add_state("preds", default=[], dist_reduce_fx="cat") |
| 29 | + self.add_state("target", default=[], dist_reduce_fx="cat") |
| 30 | + |
| 31 | + def update(self, predictions: Tensor, targets: Tensor) -> None: |
| 32 | + """Compute the PRO score for the current batch.""" |
| 33 | + |
| 34 | + self.target.append(targets) |
| 35 | + self.preds.append(predictions) |
| 36 | + |
| 37 | + def compute(self) -> Tensor: |
| 38 | + """Compute the macro average of the PRO score across all regions in all batches.""" |
| 39 | + target = dim_zero_cat(self.target) |
| 40 | + preds = dim_zero_cat(self.preds) |
| 41 | + |
| 42 | + if target.is_cuda: |
| 43 | + comps = connected_components_gpu(target.unsqueeze(1)) |
| 44 | + else: |
| 45 | + comps = connected_components_cpu(target.unsqueeze(1)) |
| 46 | + pro = pro_score(preds, comps, threshold=self.threshold) |
| 47 | + return pro |
| 48 | + |
| 49 | + |
| 50 | +def pro_score(predictions: Tensor, comps: Tensor, threshold: float = 0.5) -> Tensor: |
| 51 | + """Calculate the PRO score for a batch of predictions. |
| 52 | +
|
| 53 | + Args: |
| 54 | + predictions (Tensor): Predicted anomaly masks (Bx1xHxW) |
| 55 | + comps: (Tensor): Labeled connected components (BxHxW). The components should be labeled from 0 to N |
| 56 | + threshold (float): When predictions are passed as float, the threshold is used to binarize the predictions. |
| 57 | +
|
| 58 | + Returns: |
| 59 | + Tensor: Scalar value representing the average PRO score for the input batch. |
| 60 | + """ |
| 61 | + if predictions.dtype == torch.float: |
| 62 | + predictions = predictions > threshold |
| 63 | + |
| 64 | + n_comps = len(comps.unique()) |
| 65 | + |
| 66 | + preds = comps.clone() |
| 67 | + preds[~predictions] = 0 |
| 68 | + if n_comps == 1: # only background |
| 69 | + return torch.Tensor([1.0]) |
| 70 | + pro = recall(preds.flatten(), comps.flatten(), num_classes=n_comps, average="macro", ignore_index=0) |
| 71 | + return pro |
| 72 | + |
| 73 | + |
| 74 | +def connected_components_gpu(binary_input: Tensor, num_iterations: int = 1000) -> Tensor: |
| 75 | + """Perform connected component labeling on GPU and remap the labels from 0 to N. |
| 76 | +
|
| 77 | + Args: |
| 78 | + binary_input (Tensor): Binary input data from which we want to extract connected components (Bx1xHxW) |
| 79 | + num_iterations (int): Number of iterations used in the connected component computation. |
| 80 | +
|
| 81 | + Returns: |
| 82 | + Tensor: Components labeled from 0 to N. |
| 83 | + """ |
| 84 | + components = connected_components(binary_input, num_iterations=num_iterations) |
| 85 | + |
| 86 | + # remap component values from 0 to N |
| 87 | + labels = components.unique() |
| 88 | + for new_label, old_label in enumerate(labels): |
| 89 | + components[components == old_label] = new_label |
| 90 | + |
| 91 | + return components.int() |
| 92 | + |
| 93 | + |
| 94 | +def connected_components_cpu(image: Tensor) -> Tensor: |
| 95 | + """Connected component labeling on CPU. |
| 96 | +
|
| 97 | + Args: |
| 98 | + image (Tensor): Binary input data from which we want to extract connected components (Bx1xHxW) |
| 99 | +
|
| 100 | + Returns: |
| 101 | + Tensor: Components labeled from 0 to N. |
| 102 | + """ |
| 103 | + components = torch.zeros_like(image) |
| 104 | + label_idx = 1 |
| 105 | + for i, mask in enumerate(image): |
| 106 | + mask = mask.squeeze().numpy().astype(np.uint8) |
| 107 | + _, comps = cv2.connectedComponents(mask) |
| 108 | + # remap component values to make sure every component has a unique value when outputs are concatenated |
| 109 | + for label in np.unique(comps)[1:]: |
| 110 | + components[i, 0, ...][np.where(comps == label)] = label_idx |
| 111 | + label_idx += 1 |
| 112 | + return components.int() |
0 commit comments