Skip to content

Commit 2f9fa13

Browse files
ashwinvaidya17Ashwin Vaidya
and
Ashwin Vaidya
authored
Support null seed (#437)
* Support null seed
* Set optional value to None

Co-authored-by: Ashwin Vaidya <[email protected]>
1 parent 8f437b7 commit 2f9fa13

File tree

7 files changed

+77
-16
lines changed

7 files changed

+77
-16
lines changed

anomalib/data/btech.py

+12-3
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
import logging
2424
import shutil
25+
import warnings
2526
import zipfile
2627
from pathlib import Path
2728
from typing import Dict, Optional, Tuple, Union
@@ -56,7 +57,7 @@ def make_btech_dataset(
5657
path: Path,
5758
split: Optional[str] = None,
5859
split_ratio: float = 0.1,
59-
seed: int = 0,
60+
seed: Optional[int] = None,
6061
create_validation_set: bool = False,
6162
) -> DataFrame:
6263
"""Create BTech samples by parsing the BTech data file structure.
@@ -152,7 +153,7 @@ def __init__(
152153
pre_process: PreProcessor,
153154
split: str,
154155
task: str = "segmentation",
155-
seed: int = 0,
156+
seed: Optional[int] = None,
156157
create_validation_set: bool = False,
157158
) -> None:
158159
"""Btech Dataset class.
@@ -197,6 +198,14 @@ def __init__(
197198
(torch.Size([3, 256, 256]), torch.Size([256, 256]))
198199
"""
199200
super().__init__(root)
201+
202+
if seed is None:
203+
warnings.warn(
204+
"seed is None."
205+
" When seed is not set, images from the normal directory are split between training and test dir."
206+
" This will lead to inconsistency between runs."
207+
)
208+
200209
self.root = Path(root) if isinstance(root, str) else root
201210
self.category: str = category
202211
self.split = split
@@ -274,7 +283,7 @@ def __init__(
274283
task: str = "segmentation",
275284
transform_config_train: Optional[Union[str, A.Compose]] = None,
276285
transform_config_val: Optional[Union[str, A.Compose]] = None,
277-
seed: int = 0,
286+
seed: Optional[int] = None,
278287
create_validation_set: bool = False,
279288
) -> None:
280289
"""Instantiate BTech Lightning Data Module.

anomalib/data/folder.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def make_dataset(
9595
mask_dir: Optional[Union[str, Path]] = None,
9696
split: Optional[str] = None,
9797
split_ratio: float = 0.2,
98-
seed: int = 0,
98+
seed: Optional[int] = None,
9999
create_validation_set: bool = True,
100100
extensions: Optional[Tuple[str, ...]] = None,
101101
):
@@ -191,7 +191,7 @@ def __init__(
191191
mask_dir: Optional[Union[Path, str]] = None,
192192
extensions: Optional[Tuple[str, ...]] = None,
193193
task: Optional[str] = None,
194-
seed: int = 0,
194+
seed: Optional[int] = None,
195195
create_validation_set: bool = False,
196196
) -> None:
197197
"""Create Folder Folder Dataset.
@@ -316,7 +316,7 @@ def __init__(
316316
mask_dir: Optional[Union[Path, str]] = None,
317317
extensions: Optional[Tuple[str, ...]] = None,
318318
split_ratio: float = 0.2,
319-
seed: int = 0,
319+
seed: Optional[int] = None,
320320
image_size: Optional[Union[int, Tuple[int, int]]] = None,
321321
train_batch_size: int = 32,
322322
test_batch_size: int = 32,
@@ -425,6 +425,13 @@ def __init__(
425425
"""
426426
super().__init__()
427427

428+
if seed is None and normal_test_dir is None:
429+
raise ValueError(
430+
"Both seed and normal_test_dir cannot be None."
431+
" When seed is not set, images from the normal directory are split between training and test dir."
432+
" This will lead to inconsistency between runs."
433+
)
434+
428435
self.root = _check_and_convert_path(root)
429436
self.normal_dir = self.root / normal_dir
430437
self.abnormal_dir = self.root / abnormal_dir

anomalib/data/mvtec.py

+12-3
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040

4141
import logging
4242
import tarfile
43+
import warnings
4344
from pathlib import Path
4445
from typing import Dict, Optional, Tuple, Union
4546
from urllib.request import urlretrieve
@@ -72,7 +73,7 @@ def make_mvtec_dataset(
7273
path: Path,
7374
split: Optional[str] = None,
7475
split_ratio: float = 0.1,
75-
seed: int = 0,
76+
seed: Optional[int] = None,
7677
create_validation_set: bool = False,
7778
) -> DataFrame:
7879
"""Create MVTec AD samples by parsing the MVTec AD data file structure.
@@ -175,7 +176,7 @@ def __init__(
175176
pre_process: PreProcessor,
176177
split: str,
177178
task: str = "segmentation",
178-
seed: int = 0,
179+
seed: Optional[int] = None,
179180
create_validation_set: bool = False,
180181
) -> None:
181182
"""Mvtec AD Dataset class.
@@ -220,6 +221,14 @@ def __init__(
220221
(torch.Size([3, 256, 256]), torch.Size([256, 256]))
221222
"""
222223
super().__init__(root)
224+
225+
if seed is None:
226+
warnings.warn(
227+
"seed is None."
228+
" When seed is not set, images from the normal directory are split between training and test dir."
229+
" This will lead to inconsistency between runs."
230+
)
231+
223232
self.root = Path(root) if isinstance(root, str) else root
224233
self.category: str = category
225234
self.split = split
@@ -297,7 +306,7 @@ def __init__(
297306
task: str = "segmentation",
298307
transform_config_train: Optional[Union[str, A.Compose]] = None,
299308
transform_config_val: Optional[Union[str, A.Compose]] = None,
300-
seed: int = 0,
309+
seed: Optional[int] = None,
301310
create_validation_set: bool = False,
302311
) -> None:
303312
"""Mvtec AD Lightning Data Module.

anomalib/data/utils/split.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,13 @@
2323
# and limitations under the License.
2424

2525
import random
26+
from typing import Optional
2627

2728
from pandas.core.frame import DataFrame
2829

2930

3031
def split_normal_images_in_train_set(
31-
samples: DataFrame, split_ratio: float = 0.1, seed: int = 0, normal_label: str = "good"
32+
samples: DataFrame, split_ratio: float = 0.1, seed: Optional[int] = None, normal_label: str = "good"
3233
) -> DataFrame:
3334
"""Split normal images in train set.
3435
@@ -49,7 +50,7 @@ def split_normal_images_in_train_set(
4950
DataFrame: Output dataframe where the part of the training set is assigned to test set.
5051
"""
5152

52-
if seed >= 0:
53+
if seed:
5354
random.seed(seed)
5455

5556
normal_train_image_indices = samples.index[(samples.split == "train") & (samples.label == normal_label)].to_list()
@@ -62,7 +63,9 @@ def split_normal_images_in_train_set(
6263
return samples
6364

6465

65-
def create_validation_set_from_test_set(samples: DataFrame, seed: int = 0, normal_label: str = "good") -> DataFrame:
66+
def create_validation_set_from_test_set(
67+
samples: DataFrame, seed: Optional[int] = None, normal_label: str = "good"
68+
) -> DataFrame:
6669
"""Create Validation Set from Test Set.
6770
6871
This function creates a validation set from test set by splitting both
@@ -74,7 +77,7 @@ def create_validation_set_from_test_set(samples: DataFrame, seed: int = 0, norma
7477
normal_label (str): Name of the normal label. For MVTec AD, for instance, this is normal_label.
7578
"""
7679

77-
if seed >= 0:
80+
if seed:
7881
random.seed(seed)
7982

8083
# Split normal images.

tests/helpers/dataset.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@ def __init__(
177177
max_size: Optional[int] = 10,
178178
train_shapes: List[str] = ["triangle", "rectangle"],
179179
test_shapes: List[str] = ["star", "hexagon"],
180-
seed: int = 0,
180+
seed: Optional[int] = None,
181181
) -> None:
182182
self.root_dir = mkdtemp()
183183
self.num_train = num_train
@@ -244,7 +244,7 @@ def _generate_dataset(self):
244244

245245
def __enter__(self):
246246
"""Creates the dataset in temp folder."""
247-
if self.seed > 0:
247+
if self.seed:
248248
np.random.seed(self.seed)
249249
self._generate_dataset()
250250
return self.root_dir

tests/pre_merge/datasets/test_dataset.py

+33
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,17 @@ def test_val_and_test_dataloaders_has_mask_and_gt(self, mvtec_data_module):
9292
assert sorted(["image_path", "mask_path", "image", "label", "mask"]) == sorted(val_data.keys())
9393
assert sorted(["image_path", "mask_path", "image", "label", "mask"]) == sorted(test_data.keys())
9494

95+
def test_non_overlapping_splits(self, mvtec_data_module):
96+
"""This test ensures that the train and test splits generated are non-overlapping."""
97+
assert (
98+
len(
99+
set(mvtec_data_module.test_data.samples["image_path"].values).intersection(
100+
set(mvtec_data_module.train_data.samples["image_path"].values)
101+
)
102+
)
103+
== 0
104+
), "Found train and test split contamination"
105+
95106

96107
class TestBTechDataModule:
97108
"""Test BTech Data Module."""
@@ -111,6 +122,17 @@ def test_val_and_test_dataloaders_has_mask_and_gt(self, btech_data_module):
111122
assert sorted(["image_path", "mask_path", "image", "label", "mask"]) == sorted(val_data.keys())
112123
assert sorted(["image_path", "mask_path", "image", "label", "mask"]) == sorted(test_data.keys())
113124

125+
def test_non_overlapping_splits(self, btech_data_module):
126+
"""This test ensures that the train and test splits generated are non-overlapping."""
127+
assert (
128+
len(
129+
set(btech_data_module.test_data.samples["image_path"].values).intersection(
130+
set(btech_data_module.train_data.samples["image_path"].values)
131+
)
132+
)
133+
== 0
134+
), "Found train and test split contamination"
135+
114136

115137
class TestFolderDataModule:
116138
"""Test Folder Data Module."""
@@ -130,6 +152,17 @@ def test_val_and_test_dataloaders_has_mask_and_gt(self, folder_data_module):
130152
assert sorted(["image_path", "mask_path", "image", "label", "mask"]) == sorted(val_data.keys())
131153
assert sorted(["image_path", "mask_path", "image", "label", "mask"]) == sorted(test_data.keys())
132154

155+
def test_non_overlapping_splits(self, folder_data_module):
156+
"""This test ensures that the train and test splits generated are non-overlapping."""
157+
assert (
158+
len(
159+
set(folder_data_module.test_data.samples["image_path"].values).intersection(
160+
set(folder_data_module.train_data.samples["image_path"].values)
161+
)
162+
)
163+
== 0
164+
), "Found train and test split contamination"
165+
133166

134167
class TestDenormalize:
135168
"""Test Denormalize Util."""

tools/train.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def train():
5858
warnings.filterwarnings("ignore")
5959

6060
config = get_configurable_parameters(model_name=args.model, config_path=args.config)
61-
if config.project.seed != 0:
61+
if config.project.seed:
6262
seed_everything(config.project.seed)
6363

6464
datamodule = get_datamodule(config)

0 commit comments

Comments
 (0)