@@ -125,30 +125,40 @@ def prepare_imagenette_data(self) -> None:
125
125
@torch .no_grad ()
126
126
def teacher_channel_mean_std (self , dataloader : DataLoader ) -> dict [str , Tensor ]:
127
127
"""Calculate the mean and std of the teacher models activations.
128
+ Adapted from https://math.stackexchange.com/a/2148949
128
129
129
130
Args:
130
131
dataloader (DataLoader): Dataloader of the respective dataset.
131
132
132
133
Returns:
133
134
dict[str, Tensor]: Dictionary of channel-wise mean and std
134
135
"""
135
- y_means = []
136
- means_distance = []
137
136
138
- logger .info ("Calculate teacher channel mean and std" )
139
- for batch in tqdm .tqdm (dataloader , desc = "Calculate teacher channel mean" , position = 0 , leave = True ):
137
+ arrays_defined = False
138
+ n : torch .Tensor | None = None
139
+ chanel_sum : torch .Tensor | None = None
140
+ chanel_sum_sqr : torch .Tensor | None = None
141
+
142
+ for batch in tqdm .tqdm (dataloader , desc = "Calculate teacher channel mean & std" , position = 0 , leave = True ):
140
143
y = self .model .teacher (batch ["image" ].to (self .device ))
141
- y_means .append (torch .mean (y , dim = [0 , 2 , 3 ]))
144
+ if not arrays_defined :
145
+ _ , num_channels , _ , _ = y .shape
146
+ n = torch .zeros ((num_channels ,), dtype = torch .int64 , device = y .device )
147
+ chanel_sum = torch .zeros ((num_channels ,), dtype = torch .float64 , device = y .device )
148
+ chanel_sum_sqr = torch .zeros ((num_channels ,), dtype = torch .float64 , device = y .device )
149
+ arrays_defined = True
142
150
143
- channel_mean = torch .mean (torch .stack (y_means ), dim = 0 )[None , :, None , None ]
151
+ n += y [:, 0 ].numel ()
152
+ chanel_sum += torch .sum (y , dim = [0 , 2 , 3 ])
153
+ chanel_sum_sqr += torch .sum (y ** 2 , dim = [0 , 2 , 3 ])
144
154
145
- for batch in tqdm .tqdm (dataloader , desc = "Calculate teacher channel std" , position = 0 , leave = True ):
146
- y = self .model .teacher (batch ["image" ].to (self .device ))
147
- distance = (y - channel_mean ) ** 2
148
- means_distance .append (torch .mean (distance , dim = [0 , 2 , 3 ]))
155
+ assert n is not None
156
+
157
+ channel_mean = chanel_sum / n
158
+
159
+ channel_std = (torch .sqrt ((chanel_sum_sqr / n ) - (channel_mean ** 2 ))).float ()[None , :, None , None ]
160
+ channel_mean = channel_mean .float ()[None , :, None , None ]
149
161
150
- channel_var = torch .mean (torch .stack (means_distance ), dim = 0 )[None , :, None , None ]
151
- channel_std = torch .sqrt (channel_var )
152
162
return {"mean" : channel_mean , "std" : channel_std }
153
163
154
164
@torch .no_grad ()
0 commit comments