Skip to content

Commit 6dc28f8

Browse files
authored
Improved speed and memory usage of mean+std calculation (#1457)
1 parent 66532fc commit 6dc28f8

File tree

1 file changed

+22
-12
lines changed

1 file changed

+22
-12
lines changed

src/anomalib/models/efficient_ad/lightning_model.py

+22-12
Original file line numberDiff line numberDiff line change
@@ -125,30 +125,40 @@ def prepare_imagenette_data(self) -> None:
125125
@torch.no_grad()
126126
def teacher_channel_mean_std(self, dataloader: DataLoader) -> dict[str, Tensor]:
127127
"""Calculate the mean and std of the teacher models activations.
128+
Adapted from https://math.stackexchange.com/a/2148949
128129
129130
Args:
130131
dataloader (DataLoader): Dataloader of the respective dataset.
131132
132133
Returns:
133134
dict[str, Tensor]: Dictionary of channel-wise mean and std
134135
"""
135-
y_means = []
136-
means_distance = []
137136

138-
logger.info("Calculate teacher channel mean and std")
139-
for batch in tqdm.tqdm(dataloader, desc="Calculate teacher channel mean", position=0, leave=True):
137+
arrays_defined = False
138+
n: torch.Tensor | None = None
139+
chanel_sum: torch.Tensor | None = None
140+
chanel_sum_sqr: torch.Tensor | None = None
141+
142+
for batch in tqdm.tqdm(dataloader, desc="Calculate teacher channel mean & std", position=0, leave=True):
140143
y = self.model.teacher(batch["image"].to(self.device))
141-
y_means.append(torch.mean(y, dim=[0, 2, 3]))
144+
if not arrays_defined:
145+
_, num_channels, _, _ = y.shape
146+
n = torch.zeros((num_channels,), dtype=torch.int64, device=y.device)
147+
chanel_sum = torch.zeros((num_channels,), dtype=torch.float64, device=y.device)
148+
chanel_sum_sqr = torch.zeros((num_channels,), dtype=torch.float64, device=y.device)
149+
arrays_defined = True
142150

143-
channel_mean = torch.mean(torch.stack(y_means), dim=0)[None, :, None, None]
151+
n += y[:, 0].numel()
152+
chanel_sum += torch.sum(y, dim=[0, 2, 3])
153+
chanel_sum_sqr += torch.sum(y**2, dim=[0, 2, 3])
144154

145-
for batch in tqdm.tqdm(dataloader, desc="Calculate teacher channel std", position=0, leave=True):
146-
y = self.model.teacher(batch["image"].to(self.device))
147-
distance = (y - channel_mean) ** 2
148-
means_distance.append(torch.mean(distance, dim=[0, 2, 3]))
155+
assert n is not None
156+
157+
channel_mean = chanel_sum / n
158+
159+
channel_std = (torch.sqrt((chanel_sum_sqr / n) - (channel_mean**2))).float()[None, :, None, None]
160+
channel_mean = channel_mean.float()[None, :, None, None]
149161

150-
channel_var = torch.mean(torch.stack(means_distance), dim=0)[None, :, None, None]
151-
channel_std = torch.sqrt(channel_var)
152162
return {"mean": channel_mean, "std": channel_std}
153163

154164
@torch.no_grad()

0 commit comments

Comments
 (0)