Skip to content

Commit 0d5d26d

Browse files
Bugfix: fix random val/test split issue (#48)
* fix random val/test split issue Co-authored-by: Samet <[email protected]>
1 parent 6eadef9 commit 0d5d26d

File tree

2 files changed

+43
-34
lines changed

2 files changed

+43
-34
lines changed

anomalib/data/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ def get_datamodule(config: Union[DictConfig, ListConfig]):
4343
train_batch_size=config.dataset.train_batch_size,
4444
test_batch_size=config.dataset.test_batch_size,
4545
num_workers=config.dataset.num_workers,
46+
seed=config.project.seed,
4647
)
4748
else:
4849
raise ValueError("Unknown dataset!")

anomalib/data/mvtec.py

+42-34
Original file line numberDiff line numberDiff line change
@@ -199,10 +199,10 @@ def make_mvtec_dataset(
199199
samples.label_index = samples.label_index.astype(int)
200200

201201
if create_validation_set:
202-
samples = create_validation_set_from_test_set(samples)
202+
samples = create_validation_set_from_test_set(samples, seed=seed)
203203

204204
# Get the data frame for the split.
205-
if split is not None and split in ["train", "test"]:
205+
if split is not None and split in ["train", "val", "test"]:
206206
samples = samples[samples.split == split]
207207
samples = samples.reset_index(drop=True)
208208

@@ -217,19 +217,23 @@ def __init__(
217217
root: Union[Path, str],
218218
category: str,
219219
pre_process: PreProcessor,
220+
split: str,
220221
task: str = "segmentation",
221-
is_train: bool = True,
222222
download: bool = False,
223+
seed: int = 0,
224+
create_validation_set: bool = False,
223225
) -> None:
224226
"""Mvtec Dataset class.
225227
226228
Args:
227229
root: Path to the MVTec dataset
228230
category: Name of the MVTec category.
229231
pre_process: List of pre_processing object containing albumentation compose.
232+
split: 'train', 'val' or 'test'
230233
task: ``classification`` or ``segmentation``
231-
is_train: Boolean to check if the split is training
232234
download: Boolean to download the MVTec dataset.
235+
seed: seed used for the random subset splitting
236+
create_validation_set: Create a validation subset in addition to the train and test subsets
233237
234238
Examples:
235239
>>> from anomalib.data.mvtec import MVTec
@@ -264,15 +268,17 @@ def __init__(
264268
super().__init__(root)
265269
self.root = Path(root) if isinstance(root, str) else root
266270
self.category: str = category
267-
self.split = "train" if is_train else "test"
271+
self.split = split
268272
self.task = task
269273

270274
self.pre_process = pre_process
271275

272276
if download:
273277
self._download()
274278

275-
self.samples = make_mvtec_dataset(path=self.root / category, split=self.split)
279+
self.samples = make_mvtec_dataset(
280+
path=self.root / category, split=self.split, seed=seed, create_validation_set=create_validation_set
281+
)
276282

277283
def _download(self) -> None:
278284
"""Download the MVTec dataset."""
@@ -327,8 +333,7 @@ def __getitem__(self, index: int) -> Dict[str, Union[str, Tensor]]:
327333
if self.split == "train" or self.task == "classification":
328334
pre_processed = self.pre_process(image=image)
329335
item = {"image": pre_processed["image"]}
330-
331-
if self.split == "test":
336+
elif self.split in ["val", "test"]:
332337
label_index = self.samples.label_index[index]
333338

334339
item["image_path"] = image_path
@@ -366,6 +371,8 @@ def __init__(
366371
test_batch_size: int = 32,
367372
num_workers: int = 8,
368373
transform_config: Optional[Union[str, A.Compose]] = None,
374+
seed: int = 0,
375+
create_validation_set: bool = False,
369376
) -> None:
370377
"""Mvtec Lightning Data Module.
371378
@@ -377,6 +384,8 @@ def __init__(
377384
test_batch_size: Testing batch size.
378385
num_workers: Number of workers.
379386
transform_config: Config for pre-processing.
387+
seed: seed used for the random subset splitting
388+
create_validation_set: Create a validation subset in addition to the train and test subsets
380389
381390
Examples
382391
>>> from anomalib.data import MVTecDataModule
@@ -415,47 +424,45 @@ def __init__(
415424
self.test_batch_size = test_batch_size
416425
self.num_workers = num_workers
417426

418-
self.train_data: Dataset
419-
self.val_data: Dataset
427+
self.create_validation_set = create_validation_set
428+
self.seed = seed
420429

421-
def prepare_data(self):
422-
"""Prepare MVTec Dataset."""
423-
# Train
424-
MVTec(
425-
root=self.root,
426-
category=self.category,
427-
pre_process=self.pre_process,
428-
is_train=True,
429-
download=True,
430-
)
431-
432-
# Test
433-
MVTec(
434-
root=self.root,
435-
category=self.category,
436-
pre_process=self.pre_process,
437-
is_train=False,
438-
download=True,
439-
)
430+
self.train_data: Dataset
431+
self.test_data: Dataset
432+
if create_validation_set:
433+
self.val_data: Dataset
440434

441435
def setup(self, stage: Optional[str] = None) -> None:
442436
"""Setup train, validation and test data.
443437
444438
Args:
445439
stage: Optional[str]: Train/Val/Test stages. (Default value = None)
446440
"""
447-
self.val_data = MVTec(
441+
if self.create_validation_set:
442+
self.val_data = MVTec(
443+
root=self.root,
444+
category=self.category,
445+
pre_process=self.pre_process,
446+
split="val",
447+
seed=self.seed,
448+
create_validation_set=self.create_validation_set,
449+
)
450+
self.test_data = MVTec(
448451
root=self.root,
449452
category=self.category,
450453
pre_process=self.pre_process,
451-
is_train=False,
454+
split="test",
455+
seed=self.seed,
456+
create_validation_set=self.create_validation_set,
452457
)
453458
if stage in (None, "fit"):
454459
self.train_data = MVTec(
455460
root=self.root,
456461
category=self.category,
457462
pre_process=self.pre_process,
458-
is_train=True,
463+
split="train",
464+
seed=self.seed,
465+
create_validation_set=self.create_validation_set,
459466
)
460467

461468
def train_dataloader(self) -> DataLoader:
@@ -464,8 +471,9 @@ def train_dataloader(self) -> DataLoader:
464471

465472
def val_dataloader(self) -> DataLoader:
466473
"""Get validation dataloader."""
467-
return DataLoader(self.val_data, shuffle=False, batch_size=self.test_batch_size, num_workers=self.num_workers)
474+
dataset = self.val_data if self.create_validation_set else self.test_data
475+
return DataLoader(dataset=dataset, shuffle=False, batch_size=self.test_batch_size, num_workers=self.num_workers)
468476

469477
def test_dataloader(self) -> DataLoader:
470478
"""Get test dataloader."""
471-
return DataLoader(self.val_data, shuffle=False, batch_size=self.test_batch_size, num_workers=self.num_workers)
479+
return DataLoader(self.test_data, shuffle=False, batch_size=self.test_batch_size, num_workers=self.num_workers)

0 commit comments

Comments
 (0)