@@ -199,10 +199,10 @@ def make_mvtec_dataset(
199
199
samples .label_index = samples .label_index .astype (int )
200
200
201
201
if create_validation_set :
202
- samples = create_validation_set_from_test_set (samples )
202
+ samples = create_validation_set_from_test_set (samples , seed = seed )
203
203
204
204
# Get the data frame for the split.
205
- if split is not None and split in ["train" , "test" ]:
205
+ if split is not None and split in ["train" , "val" , "test" ]:
206
206
samples = samples [samples .split == split ]
207
207
samples = samples .reset_index (drop = True )
208
208
@@ -217,19 +217,23 @@ def __init__(
217
217
root : Union [Path , str ],
218
218
category : str ,
219
219
pre_process : PreProcessor ,
220
+ split : str ,
220
221
task : str = "segmentation" ,
221
- is_train : bool = True ,
222
222
download : bool = False ,
223
+ seed : int = 0 ,
224
+ create_validation_set : bool = False ,
223
225
) -> None :
224
226
"""Mvtec Dataset class.
225
227
226
228
Args:
227
229
root: Path to the MVTec dataset
228
230
category: Name of the MVTec category.
229
231
pre_process: List of pre_processing object containing albumentation compose.
232
+ split: 'train', 'val' or 'test'
230
233
task: ``classification`` or ``segmentation``
231
- is_train: Boolean to check if the split is training
232
234
download: Boolean to download the MVTec dataset.
235
+ seed: Seed passed to the random split that creates the validation subset.
236
+ create_validation_set: If True, split a validation subset off the test set in addition to the train and test subsets.
233
237
234
238
Examples:
235
239
>>> from anomalib.data.mvtec import MVTec
@@ -264,15 +268,17 @@ def __init__(
264
268
super ().__init__ (root )
265
269
self .root = Path (root ) if isinstance (root , str ) else root
266
270
self .category : str = category
267
- self .split = "train" if is_train else "test"
271
+ self .split = split
268
272
self .task = task
269
273
270
274
self .pre_process = pre_process
271
275
272
276
if download :
273
277
self ._download ()
274
278
275
- self .samples = make_mvtec_dataset (path = self .root / category , split = self .split )
279
+ self .samples = make_mvtec_dataset (
280
+ path = self .root / category , split = self .split , seed = seed , create_validation_set = create_validation_set
281
+ )
276
282
277
283
def _download (self ) -> None :
278
284
"""Download the MVTec dataset."""
@@ -327,8 +333,7 @@ def __getitem__(self, index: int) -> Dict[str, Union[str, Tensor]]:
327
333
if self .split == "train" or self .task == "classification" :
328
334
pre_processed = self .pre_process (image = image )
329
335
item = {"image" : pre_processed ["image" ]}
330
-
331
- if self .split == "test" :
336
+ if self .split in ["val" , "test" ]:
332
337
label_index = self .samples .label_index [index ]
333
338
334
339
item ["image_path" ] = image_path
@@ -366,6 +371,8 @@ def __init__(
366
371
test_batch_size : int = 32 ,
367
372
num_workers : int = 8 ,
368
373
transform_config : Optional [Union [str , A .Compose ]] = None ,
374
+ seed : int = 0 ,
375
+ create_validation_set : bool = False ,
369
376
) -> None :
370
377
"""Mvtec Lightning Data Module.
371
378
@@ -377,6 +384,8 @@ def __init__(
377
384
test_batch_size: Testing batch size.
378
385
num_workers: Number of workers.
379
386
transform_config: Config for pre-processing.
387
+ seed: Seed passed to each MVTec dataset for the random validation split.
388
+ create_validation_set: If True, build a separate validation subset; if False, the validation dataloader serves the test set.
380
389
381
390
Examples
382
391
>>> from anomalib.data import MVTecDataModule
@@ -415,47 +424,45 @@ def __init__(
415
424
self .test_batch_size = test_batch_size
416
425
self .num_workers = num_workers
417
426
418
- self .train_data : Dataset
419
- self .val_data : Dataset
427
+ self .create_validation_set = create_validation_set
428
+ self .seed = seed
420
429
421
- def prepare_data (self ):
422
- """Prepare MVTec Dataset."""
423
- # Train
424
- MVTec (
425
- root = self .root ,
426
- category = self .category ,
427
- pre_process = self .pre_process ,
428
- is_train = True ,
429
- download = True ,
430
- )
431
-
432
- # Test
433
- MVTec (
434
- root = self .root ,
435
- category = self .category ,
436
- pre_process = self .pre_process ,
437
- is_train = False ,
438
- download = True ,
439
- )
430
+ self .train_data : Dataset
431
+ self .test_data : Dataset
432
+ if create_validation_set :
433
+ self .val_data : Dataset
440
434
441
435
def setup (self , stage : Optional [str ] = None ) -> None :
442
436
"""Setup train, validation and test data.
443
437
444
438
Args:
445
439
stage: Optional[str]: Train/Val/Test stages. (Default value = None)
446
440
"""
447
- self .val_data = MVTec (
441
+ if self .create_validation_set :
442
+ self .val_data = MVTec (
443
+ root = self .root ,
444
+ category = self .category ,
445
+ pre_process = self .pre_process ,
446
+ split = "val" ,
447
+ seed = self .seed ,
448
+ create_validation_set = self .create_validation_set ,
449
+ )
450
+ self .test_data = MVTec (
448
451
root = self .root ,
449
452
category = self .category ,
450
453
pre_process = self .pre_process ,
451
- is_train = False ,
454
+ split = "test" ,
455
+ seed = self .seed ,
456
+ create_validation_set = self .create_validation_set ,
452
457
)
453
458
if stage in (None , "fit" ):
454
459
self .train_data = MVTec (
455
460
root = self .root ,
456
461
category = self .category ,
457
462
pre_process = self .pre_process ,
458
- is_train = True ,
463
+ split = "train" ,
464
+ seed = self .seed ,
465
+ create_validation_set = self .create_validation_set ,
459
466
)
460
467
461
468
def train_dataloader (self ) -> DataLoader :
@@ -464,8 +471,9 @@ def train_dataloader(self) -> DataLoader:
464
471
465
472
def val_dataloader (self ) -> DataLoader :
466
473
"""Get validation dataloader."""
467
- return DataLoader (self .val_data , shuffle = False , batch_size = self .test_batch_size , num_workers = self .num_workers )
474
+ dataset = self .val_data if self .create_validation_set else self .test_data
475
+ return DataLoader (dataset = dataset , shuffle = False , batch_size = self .test_batch_size , num_workers = self .num_workers )
468
476
469
477
def test_dataloader (self ) -> DataLoader :
470
478
"""Get test dataloader."""
471
- return DataLoader (self .val_data , shuffle = False , batch_size = self .test_batch_size , num_workers = self .num_workers )
479
+ return DataLoader (self .test_data , shuffle = False , batch_size = self .test_batch_size , num_workers = self .num_workers )
0 commit comments