13
13
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
14
# See the License for the specific language governing permissions
15
15
# and limitations under the License.
16
+ import random
16
17
17
18
import pytest
18
19
import torch
20
+ from pytorch_lightning import Trainer
19
21
22
+ from anomalib .config import get_configurable_parameters
23
+ from anomalib .data import get_datamodule
24
+ from anomalib .models import get_model
25
+ from anomalib .utils .callbacks import get_callbacks
20
26
from anomalib .utils .metrics import AdaptiveThreshold
21
27
22
28
@@ -35,3 +41,29 @@ def test_adaptive_threshold(labels, preds, target_threshold):
35
41
threshold_value = adaptive_threshold .compute ()
36
42
37
43
assert threshold_value == target_threshold
44


def test_non_adaptive_threshold():
    """
    Test if the non-adaptive threshold gets used in the F1 score computation when
    adaptive thresholding is disabled and no normalization is used.
    """
    config = get_configurable_parameters(model_config_path="anomalib/models/padim/config.yaml")

    # Turn off score normalization and adaptive thresholding, and keep the
    # training run minimal so the test stays fast.
    config.model.normalization_method = "none"
    config.model.threshold.adaptive = False
    config.trainer.fast_dev_run = True

    # Write arbitrary manual thresholds into the config; after training, the
    # F1 metrics should report exactly these values.
    expected_image_threshold = random.random()
    expected_pixel_threshold = random.random()
    config.model.threshold.image_default = expected_image_threshold
    config.model.threshold.pixel_default = expected_pixel_threshold

    anomaly_model = get_model(config)
    data = get_datamodule(config)
    callback_list = get_callbacks(config)

    trainer = Trainer(**config.trainer, callbacks=callback_list)
    trainer.fit(model=anomaly_model, datamodule=data)

    # The metric objects must have picked up the manually configured defaults.
    assert trainer.model.image_metrics.F1.threshold == expected_image_threshold
    assert trainer.model.pixel_metrics.F1.threshold == expected_pixel_threshold