Skip to content

Commit 3c39944

Browse files
committed
Dataclasses and post-processing refactor (open-edge-platform#2098)
* use dataclass for model in- and outputs * split dataclass in image and video * use dataclass in torch inferencer * use dataclass in openvino inferencer * add post_processor class * remove default metrics from CLI * export post processing * add post processor to patchcore * use named tuple for inference outputs * validate and format inputs of PredictBatch * update torch inference * remove base inferencer inheritance * update openvino inference * fix visualization * PredictBatch -> Batch * post processor as callback * use callback methods to apply post processing * temporary fix for visualization * add DatasetItem class * fix pred_score shape and add __len__ * make batch iterable * add in place replace method * use dataset items in inference * dataset_items -> items * use namedtuple as torch model outputs * formatting * split dataclasses into input/output and image/video * merge input and output classes * use init_subclass for attribute checking * add descriptor class for validation * improve error handling * DataClassDescriptor -> FieldDescriptor * add is_optional method * add input validation for torch image and batch * use image and video dataclasses in library * add more validation * add validation * make postprocessor configurable from engine * fix post processing logic * fix data tests * remove detection task type * fix more tests * use separate normalization stats for image and pixel preds * add sensitivity parameters to one class pp * fix utils tests * fix utils tests * remove metric serialization test * remove normalization and thresholding args * set default post processor in base model * remove manual threshold test * fix remaining unit tests * add post_processor to CLI args * remove old post processing callbacks * remove comment * remove references to old normalization and thresholding callbacks * remove reshape in openvino inferencer * export lightning model directly * make collate accessible from dataset * fix tools integration tests * add update method to dataclasses * allow missing pred_score or anomaly_map in post processor * fix exportable centercrop conversion * fix model tests * test all models * fix efficient_ad * post processor as model arg * disable rkde tests * fix winclip export * add copyright notice * add validation for numpy anomaly map * fix getting started notebook * remove hardcoded path * update dataset notebooks * update model notebooks * fix logging notebooks * fix model notebook
1 parent cdd338c commit 3c39944

File tree

96 files changed

+2457
-4876
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

96 files changed

+2457
-4876
lines changed

notebooks/000_getting_started/001_getting_started.ipynb

+204-150
Large diffs are not rendered by default.

notebooks/100_datamodules/101_btech.ipynb

+43-321
Large diffs are not rendered by default.

notebooks/100_datamodules/102_mvtec.ipynb

+33-297
Large diffs are not rendered by default.

notebooks/100_datamodules/103_folder.ipynb

+36-476
Large diffs are not rendered by default.

notebooks/100_datamodules/104_tiling.ipynb

+17-70
Large diffs are not rendered by default.

notebooks/200_models/201_fastflow.ipynb

+32-298
Large diffs are not rendered by default.

notebooks/600_loggers/601_mlflow_logging.ipynb

+27-1,374
Large diffs are not rendered by default.

src/anomalib/__init__.py

-1
Original file line numberDiff line numberDiff line change
@@ -20,5 +20,4 @@ class TaskType(str, Enum):
2020
"""Task type used when generating predictions on the dataset."""
2121

2222
CLASSIFICATION = "classification"
23-
DETECTION = "detection"
2423
SEGMENTATION = "segmentation"

src/anomalib/callbacks/metrics.py

+9-12
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# SPDX-License-Identifier: Apache-2.0
55

66
import logging
7+
from dataclasses import asdict
78
from enum import Enum
89
from typing import Any
910

@@ -12,6 +13,7 @@
1213
from lightning.pytorch.utilities.types import STEP_OUTPUT
1314

1415
from anomalib import TaskType
16+
from anomalib.dataclasses import Batch
1517
from anomalib.metrics import AnomalibMetricCollection, create_metric_collection
1618
from anomalib.models import AnomalyModule
1719

@@ -96,7 +98,6 @@ def setup(
9698
pl_module.pixel_metrics.add_metrics(new_metrics[name])
9799
else:
98100
pl_module.pixel_metrics = create_metric_collection(pixel_metric_names, "pixel_")
99-
self._set_threshold(pl_module)
100101

101102
@staticmethod
102103
def on_validation_epoch_start(trainer: Trainer, pl_module: AnomalyModule) -> None:
@@ -117,13 +118,12 @@ def on_validation_batch_end(
117118
del trainer, batch, batch_idx, dataloader_idx # Unused arguments.
118119

119120
if outputs is not None:
120-
self._outputs_to_device(outputs)
121+
outputs = self._outputs_to_device(outputs)
121122
self._update_metrics(pl_module.image_metrics, pl_module.pixel_metrics, outputs)
122123

123124
def on_validation_epoch_end(self, trainer: Trainer, pl_module: AnomalyModule) -> None:
124125
del trainer # Unused argument.
125126

126-
self._set_threshold(pl_module)
127127
self._log_metrics(pl_module)
128128

129129
@staticmethod
@@ -145,35 +145,32 @@ def on_test_batch_end(
145145
del trainer, batch, batch_idx, dataloader_idx # Unused arguments.
146146

147147
if outputs is not None:
148-
self._outputs_to_device(outputs)
148+
outputs = self._outputs_to_device(outputs)
149149
self._update_metrics(pl_module.image_metrics, pl_module.pixel_metrics, outputs)
150150

151151
def on_test_epoch_end(self, trainer: Trainer, pl_module: AnomalyModule) -> None:
152152
del trainer # Unused argument.
153153

154154
self._log_metrics(pl_module)
155155

156-
@staticmethod
157-
def _set_threshold(pl_module: AnomalyModule) -> None:
158-
pl_module.image_metrics.set_threshold(pl_module.image_threshold.value.item())
159-
pl_module.pixel_metrics.set_threshold(pl_module.pixel_threshold.value.item())
160-
161156
def _update_metrics(
162157
self,
163158
image_metric: AnomalibMetricCollection,
164159
pixel_metric: AnomalibMetricCollection,
165160
output: STEP_OUTPUT,
166161
) -> None:
167162
image_metric.to(self.device)
168-
image_metric.update(output["pred_scores"], output["label"].int())
169-
if "mask" in output and "anomaly_maps" in output:
163+
image_metric.update(output.pred_score, output.gt_label.int())
164+
if output.gt_mask is not None and output.anomaly_map is not None:
170165
pixel_metric.to(self.device)
171-
pixel_metric.update(torch.squeeze(output["anomaly_maps"]), torch.squeeze(output["mask"].int()))
166+
pixel_metric.update(torch.squeeze(output.anomaly_map), torch.squeeze(output.gt_mask.int()))
172167

173168
def _outputs_to_device(self, output: STEP_OUTPUT) -> STEP_OUTPUT | dict[str, Any]:
174169
if isinstance(output, dict):
175170
for key, value in output.items():
176171
output[key] = self._outputs_to_device(value)
172+
elif isinstance(output, Batch):
173+
output = output.__class__(**self._outputs_to_device(asdict(output)))
177174
elif isinstance(output, torch.Tensor):
178175
output = output.to(self.device)
179176
return output

src/anomalib/callbacks/normalization/__init__.py

-12
This file was deleted.

src/anomalib/callbacks/normalization/base.py

-29
This file was deleted.

src/anomalib/callbacks/normalization/min_max_normalization.py

-131
This file was deleted.

src/anomalib/callbacks/normalization/utils.py

-78
This file was deleted.

0 commit comments

Comments
 (0)