Skip to content

Commit b66e5e3

Browse files
authored
fix non-adaptive thresholding bug (#152)
1 parent 834d45a commit b66e5e3

File tree

2 files changed

+38
-4
lines changed

2 files changed

+38
-4
lines changed

anomalib/models/components/base/anomaly_module.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -58,10 +58,12 @@ def __init__(self, params: Union[DictConfig, ListConfig]):
5858
self.model: nn.Module
5959

6060
# metrics
61-
auroc = AUROC(num_classes=1, pos_label=1, compute_on_step=False)
62-
f1_score = F1(num_classes=1, compute_on_step=False)
63-
self.image_metrics = MetricCollection([auroc, f1_score], prefix="image_").cpu()
64-
self.pixel_metrics = self.image_metrics.clone(prefix="pixel_").cpu()
61+
image_auroc = AUROC(num_classes=1, pos_label=1, compute_on_step=False)
62+
image_f1 = F1(num_classes=1, compute_on_step=False, threshold=self.hparams.model.threshold.image_default)
63+
pixel_auroc = AUROC(num_classes=1, pos_label=1, compute_on_step=False)
64+
pixel_f1 = F1(num_classes=1, compute_on_step=False, threshold=self.hparams.model.threshold.pixel_default)
65+
self.image_metrics = MetricCollection([image_auroc, image_f1], prefix="image_").cpu()
66+
self.pixel_metrics = MetricCollection([pixel_auroc, pixel_f1], prefix="pixel_").cpu()
6567

6668
def forward(self, batch): # pylint: disable=arguments-differ
6769
"""Forward-pass input tensor to the module.

tests/pre_merge/utils/metrics/test_adaptive_threshold.py

+32
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,16 @@
1313
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1414
# See the License for the specific language governing permissions
1515
# and limitations under the License.
16+
import random
1617

1718
import pytest
1819
import torch
20+
from pytorch_lightning import Trainer
1921

22+
from anomalib.config import get_configurable_parameters
23+
from anomalib.data import get_datamodule
24+
from anomalib.models import get_model
25+
from anomalib.utils.callbacks import get_callbacks
2026
from anomalib.utils.metrics import AdaptiveThreshold
2127

2228

@@ -35,3 +41,29 @@ def test_adaptive_threshold(labels, preds, target_threshold):
3541
threshold_value = adaptive_threshold.compute()
3642

3743
assert threshold_value == target_threshold
44+
45+
46+
def test_non_adaptive_threshold():
47+
"""
48+
Test if the non-adaptive threshold gets used in the F1 score computation when
49+
adaptive thresholding is disabled and no normalization is used.
50+
"""
51+
config = get_configurable_parameters(model_config_path="anomalib/models/padim/config.yaml")
52+
53+
config.model.normalization_method = "none"
54+
config.model.threshold.adaptive = False
55+
config.trainer.fast_dev_run = True
56+
57+
image_threshold = random.random()
58+
pixel_threshold = random.random()
59+
config.model.threshold.image_default = image_threshold
60+
config.model.threshold.pixel_default = pixel_threshold
61+
62+
model = get_model(config)
63+
datamodule = get_datamodule(config)
64+
callbacks = get_callbacks(config)
65+
66+
trainer = Trainer(**config.trainer, callbacks=callbacks)
67+
trainer.fit(model=model, datamodule=datamodule)
68+
assert trainer.model.image_metrics.F1.threshold == image_threshold
69+
assert trainer.model.pixel_metrics.F1.threshold == pixel_threshold

0 commit comments

Comments
 (0)