Skip to content

Commit 165702f

Browse files
Made-imagenette-path-configurable-in-config (#1833)
Signed-off-by: sahusiddharth <[email protected]> Co-authored-by: Samet Akcay <[email protected]>
1 parent 8c6e607 commit 165702f

File tree

1 file changed

+7
-4
lines changed

1 file changed

+7
-4
lines changed

src/anomalib/models/image/efficient_ad/lightning_model.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ class EfficientAd(AnomalyModule):
4343
"""PL Lightning Module for the EfficientAd algorithm.
4444
4545
Args:
46+
imagenet_dir (Path|str): directory path for the Imagenet dataset
47+
Defaults to ``./datasets/imagenette``.
4648
teacher_out_channels (int): number of convolution output channels
4749
Defaults to ``384``.
4850
model_size (str): size of student and teacher model
@@ -62,6 +64,7 @@ class EfficientAd(AnomalyModule):
6264

6365
def __init__(
6466
self,
67+
imagenet_dir: Path | str = "./datasets/imagenette",
6568
teacher_out_channels: int = 384,
6669
model_size: EfficientAdModelSize = EfficientAdModelSize.S,
6770
lr: float = 0.0001,
@@ -72,6 +75,7 @@ def __init__(
7275
) -> None:
7376
super().__init__()
7477

78+
self.imagenet_dir = Path(imagenet_dir)
7579
self.model_size = model_size
7680
self.model: EfficientAdModel = EfficientAdModel(
7781
teacher_out_channels=teacher_out_channels,
@@ -109,10 +113,9 @@ def prepare_imagenette_data(self, image_size: tuple[int, int] | torch.Size) -> N
109113
],
110114
)
111115

112-
imagenet_dir = Path("./datasets/imagenette")
113-
if not imagenet_dir.is_dir():
114-
download_and_extract(imagenet_dir, IMAGENETTE_DOWNLOAD_INFO)
115-
imagenet_dataset = ImageFolder(imagenet_dir, transform=self.data_transforms_imagenet)
116+
if not self.imagenet_dir.is_dir():
117+
download_and_extract(self.imagenet_dir, IMAGENETTE_DOWNLOAD_INFO)
118+
imagenet_dataset = ImageFolder(self.imagenet_dir, transform=self.data_transforms_imagenet)
116119
self.imagenet_loader = DataLoader(imagenet_dataset, batch_size=self.batch_size, shuffle=True, pin_memory=True)
117120
self.imagenet_iterator = iter(self.imagenet_loader)
118121

0 commit comments

Comments
 (0)