-
Notifications
You must be signed in to change notification settings - Fork 737
Dataclasses and post-processing refactor #2098
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
68c5582
ddfcd5f
32e038d
675dd3f
5779ab7
0662558
fddbeb1
e32bd7d
90265e8
e3a9c1d
89f972c
08bdae2
2bc76fc
f7c7f9a
4160ab3
fd9eb24
87facb6
2269a78
9652b9f
082bbbc
dbabb20
b190cd3
ed904eb
773e54a
f8d999a
67046dd
d00b938
9fb4549
86cf632
fa3b874
2761600
c650dfc
ced34ca
12cd32d
b447cab
213c2b4
fb80feb
d2337a7
b53f1f7
5f16147
9203318
e99d630
631ba97
b750042
b37e265
86a365d
fcbb628
0fc3337
7ec9dd7
f5a48cd
afaec9b
e0a70c8
211d9f8
eb584eb
442c37f
bd59184
e17eda5
3140e8b
af99bed
039be2a
381e638
daead5b
987abe5
a709c6c
a37fa3b
14da4fa
beb3b97
a9d07db
6bcca36
014cb59
58df063
25845fb
1defdba
8d60276
0afb6d9
a26efb9
e7d9852
a4bcbfe
085c4aa
eff1f97
40bb4be
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Large diffs are not rendered by default.
Large diffs are not rendered by default.
Large diffs are not rendered by default.
Large diffs are not rendered by default.
Large diffs are not rendered by default.
Large diffs are not rendered by default.
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -4,6 +4,7 @@ | |||||
# SPDX-License-Identifier: Apache-2.0 | ||||||
|
||||||
import logging | ||||||
from dataclasses import asdict | ||||||
from enum import Enum | ||||||
from typing import Any | ||||||
|
||||||
|
@@ -12,6 +13,7 @@ | |||||
from lightning.pytorch.utilities.types import STEP_OUTPUT | ||||||
|
||||||
from anomalib import TaskType | ||||||
from anomalib.dataclasses import Batch | ||||||
from anomalib.metrics import AnomalibMetricCollection, create_metric_collection | ||||||
from anomalib.models import AnomalyModule | ||||||
|
||||||
|
@@ -96,7 +98,6 @@ def setup( | |||||
pl_module.pixel_metrics.add_metrics(new_metrics[name]) | ||||||
else: | ||||||
pl_module.pixel_metrics = create_metric_collection(pixel_metric_names, "pixel_") | ||||||
self._set_threshold(pl_module) | ||||||
|
||||||
def on_validation_epoch_start( | ||||||
self, | ||||||
|
@@ -120,7 +121,7 @@ def on_validation_batch_end( | |||||
del trainer, batch, batch_idx, dataloader_idx # Unused arguments. | ||||||
|
||||||
if outputs is not None: | ||||||
self._outputs_to_device(outputs) | ||||||
outputs = self._outputs_to_device(outputs) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is for future reference... I hope to get rid of this device-related stuff and leave it to Lightning |
||||||
self._update_metrics(pl_module.image_metrics, pl_module.pixel_metrics, outputs) | ||||||
|
||||||
def on_validation_epoch_end( | ||||||
|
@@ -130,7 +131,6 @@ def on_validation_epoch_end( | |||||
) -> None: | ||||||
del trainer # Unused argument. | ||||||
|
||||||
self._set_threshold(pl_module) | ||||||
self._log_metrics(pl_module) | ||||||
|
||||||
def on_test_epoch_start( | ||||||
|
@@ -155,7 +155,7 @@ def on_test_batch_end( | |||||
del trainer, batch, batch_idx, dataloader_idx # Unused arguments. | ||||||
|
||||||
if outputs is not None: | ||||||
self._outputs_to_device(outputs) | ||||||
outputs = self._outputs_to_device(outputs) | ||||||
self._update_metrics(pl_module.image_metrics, pl_module.pixel_metrics, outputs) | ||||||
|
||||||
def on_test_epoch_end( | ||||||
|
@@ -167,26 +167,24 @@ def on_test_epoch_end( | |||||
|
||||||
self._log_metrics(pl_module) | ||||||
|
||||||
def _set_threshold(self, pl_module: AnomalyModule) -> None: | ||||||
pl_module.image_metrics.set_threshold(pl_module.image_threshold.value.item()) | ||||||
pl_module.pixel_metrics.set_threshold(pl_module.pixel_threshold.value.item()) | ||||||
|
||||||
def _update_metrics( | ||||||
self, | ||||||
image_metric: AnomalibMetricCollection, | ||||||
pixel_metric: AnomalibMetricCollection, | ||||||
output: STEP_OUTPUT, | ||||||
) -> None: | ||||||
image_metric.to(self.device) | ||||||
image_metric.update(output["pred_scores"], output["label"].int()) | ||||||
if "mask" in output and "anomaly_maps" in output: | ||||||
image_metric.update(output.pred_score, output.gt_label.int()) | ||||||
if output.gt_mask is not None and output.anomaly_map is not None: | ||||||
pixel_metric.to(self.device) | ||||||
pixel_metric.update(torch.squeeze(output["anomaly_maps"]), torch.squeeze(output["mask"].int())) | ||||||
pixel_metric.update(torch.squeeze(output.anomaly_map), torch.squeeze(output.gt_mask.int())) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
|
||||||
def _outputs_to_device(self, output: STEP_OUTPUT) -> STEP_OUTPUT | dict[str, Any]: | ||||||
if isinstance(output, dict): | ||||||
for key, value in output.items(): | ||||||
output[key] = self._outputs_to_device(value) | ||||||
elif isinstance(output, Batch): | ||||||
output = output.__class__(**self._outputs_to_device(asdict(output))) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Would it be an idea to add a comment here? It might be difficult to understand for some readers |
||||||
elif isinstance(output, torch.Tensor): | ||||||
output = output.to(self.device) | ||||||
return output | ||||||
|
This file was deleted.
This file was deleted.
This file was deleted.
This file was deleted.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am not sure if this is the best place to store these objects