@@ -43,6 +43,8 @@ class EfficientAd(AnomalyModule):
43
43
"""PL Lightning Module for the EfficientAd algorithm.
44
44
45
45
Args:
46
+ imagenet_dir (Path|str): directory path for the Imagenet dataset
47
+ Defaults to ``./datasets/imagenette``.
46
48
teacher_out_channels (int): number of convolution output channels
47
49
Defaults to ``384``.
48
50
model_size (str): size of student and teacher model
@@ -62,6 +64,7 @@ class EfficientAd(AnomalyModule):
62
64
63
65
def __init__ (
64
66
self ,
67
+ imagenet_dir : Path | str = "./datasets/imagenette" ,
65
68
teacher_out_channels : int = 384 ,
66
69
model_size : EfficientAdModelSize = EfficientAdModelSize .S ,
67
70
lr : float = 0.0001 ,
@@ -72,6 +75,7 @@ def __init__(
72
75
) -> None :
73
76
super ().__init__ ()
74
77
78
+ self .imagenet_dir = Path (imagenet_dir )
75
79
self .model_size = model_size
76
80
self .model : EfficientAdModel = EfficientAdModel (
77
81
teacher_out_channels = teacher_out_channels ,
@@ -109,10 +113,9 @@ def prepare_imagenette_data(self, image_size: tuple[int, int] | torch.Size) -> N
109
113
],
110
114
)
111
115
112
- imagenet_dir = Path ("./datasets/imagenette" )
113
- if not imagenet_dir .is_dir ():
114
- download_and_extract (imagenet_dir , IMAGENETTE_DOWNLOAD_INFO )
115
- imagenet_dataset = ImageFolder (imagenet_dir , transform = self .data_transforms_imagenet )
116
+ if not self .imagenet_dir .is_dir ():
117
+ download_and_extract (self .imagenet_dir , IMAGENETTE_DOWNLOAD_INFO )
118
+ imagenet_dataset = ImageFolder (self .imagenet_dir , transform = self .data_transforms_imagenet )
116
119
self .imagenet_loader = DataLoader (imagenet_dataset , batch_size = self .batch_size , shuffle = True , pin_memory = True )
117
120
self .imagenet_iterator = iter (self .imagenet_loader )
118
121
0 commit comments