15
15
Adapted from official impl at https://github.com/jameslahm/RepViT
16
16
"""
17
17
18
- __all__ = ['RepViT ' ]
18
+ __all__ = ['RepVit ' ]
19
19
20
20
import torch .nn as nn
21
21
from timm .data import IMAGENET_DEFAULT_MEAN , IMAGENET_DEFAULT_STD
@@ -81,7 +81,7 @@ def fuse(self):
81
81
return m
82
82
83
83
84
- class RepVGGDW (nn .Module ):
84
+ class RepVggDw (nn .Module ):
85
85
def __init__ (self , ed , kernel_size ):
86
86
super ().__init__ ()
87
87
self .conv = ConvNorm (ed , ed , kernel_size , 1 , (kernel_size - 1 ) // 2 , groups = ed )
@@ -115,7 +115,7 @@ def fuse(self):
115
115
return conv
116
116
117
117
118
- class RepViTMlp (nn .Module ):
118
+ class RepVitMlp (nn .Module ):
119
119
def __init__ (self , in_dim , hidden_dim , act_layer ):
120
120
super ().__init__ ()
121
121
self .conv1 = ConvNorm (in_dim , hidden_dim , 1 , 1 , 0 )
@@ -130,9 +130,9 @@ class RepViTBlock(nn.Module):
130
130
def __init__ (self , in_dim , mlp_ratio , kernel_size , use_se , act_layer ):
131
131
super (RepViTBlock , self ).__init__ ()
132
132
133
- self .token_mixer = RepVGGDW (in_dim , kernel_size )
133
+ self .token_mixer = RepVggDw (in_dim , kernel_size )
134
134
self .se = SqueezeExcite (in_dim , 0.25 ) if use_se else nn .Identity ()
135
- self .channel_mixer = RepViTMlp (in_dim , in_dim * mlp_ratio , act_layer )
135
+ self .channel_mixer = RepVitMlp (in_dim , in_dim * mlp_ratio , act_layer )
136
136
137
137
def forward (self , x ):
138
138
x = self .token_mixer (x )
@@ -142,7 +142,7 @@ def forward(self, x):
142
142
return identity + x
143
143
144
144
145
- class RepViTStem (nn .Module ):
145
+ class RepVitStem (nn .Module ):
146
146
def __init__ (self , in_chs , out_chs , act_layer ):
147
147
super ().__init__ ()
148
148
self .conv1 = ConvNorm (in_chs , out_chs // 2 , 3 , 2 , 1 )
@@ -154,13 +154,13 @@ def forward(self, x):
154
154
return self .conv2 (self .act1 (self .conv1 (x )))
155
155
156
156
157
- class RepViTDownsample (nn .Module ):
157
+ class RepVitDownsample (nn .Module ):
158
158
def __init__ (self , in_dim , mlp_ratio , out_dim , kernel_size , act_layer ):
159
159
super ().__init__ ()
160
160
self .pre_block = RepViTBlock (in_dim , mlp_ratio , kernel_size , use_se = False , act_layer = act_layer )
161
161
self .spatial_downsample = ConvNorm (in_dim , in_dim , kernel_size , 2 , (kernel_size - 1 ) // 2 , groups = in_dim )
162
162
self .channel_downsample = ConvNorm (in_dim , out_dim , 1 , 1 )
163
- self .ffn = RepViTMlp (out_dim , out_dim * mlp_ratio , act_layer )
163
+ self .ffn = RepVitMlp (out_dim , out_dim * mlp_ratio , act_layer )
164
164
165
165
def forward (self , x ):
166
166
x = self .pre_block (x )
@@ -171,22 +171,25 @@ def forward(self, x):
171
171
return x + identity
172
172
173
173
174
- class RepViTClassifier (nn .Module ):
175
- def __init__ (self , dim , num_classes , distillation = False ):
174
+ class RepVitClassifier (nn .Module ):
175
def __init__(self, dim, num_classes, distillation=False, drop=0.):
    """Build the classification head(s).

    Args:
        dim: Input feature width handed to ``NormLinear``.
        num_classes: Output width; when not positive the head is ``nn.Identity``.
        distillation: When true, a second head ``head_dist`` is created as well.
        drop: Probability passed to the ``nn.Dropout`` applied before the head(s).
    """
    super().__init__()

    def _make_head():
        # Same construction rule for the main and the distillation head.
        if num_classes > 0:
            return NormLinear(dim, num_classes)
        return nn.Identity()

    # Submodules are registered in the order head_drop, head, (head_dist).
    self.head_drop = nn.Dropout(drop)
    self.head = _make_head()
    if distillation:
        self.head_dist = _make_head()

    self.distillation = distillation
    # Tuple output during training is opt-in; it starts disabled.
    self.distilled_training = False
    self.num_classes = num_classes
182
184
183
185
def forward(self, x):
    """Apply dropout and the classification head(s) to ``x``.

    Without distillation the single head's output is returned. With
    distillation, both heads are evaluated; their outputs are returned as a
    ``(head, head_dist)`` tuple only while training with ``distilled_training``
    enabled and not under TorchScript scripting, otherwise their mean is
    returned.
    """
    x = self.head_drop(x)

    if not self.distillation:
        return self.head(x)

    logits, logits_dist = self.head(x), self.head_dist(x)
    if self.training and self.distilled_training and not torch.jit.is_scripting():
        # Separate outputs so a distillation loss can consume each head.
        return logits, logits_dist
    return (logits + logits_dist) / 2
@@ -207,11 +210,11 @@ def fuse(self):
207
210
return head
208
211
209
212
210
- class RepViTStage (nn .Module ):
213
+ class RepVitStage (nn .Module ):
211
214
def __init__ (self , in_dim , out_dim , depth , mlp_ratio , act_layer , kernel_size = 3 , downsample = True ):
212
215
super ().__init__ ()
213
216
if downsample :
214
- self .downsample = RepViTDownsample (in_dim , mlp_ratio , out_dim , kernel_size , act_layer )
217
+ self .downsample = RepVitDownsample (in_dim , mlp_ratio , out_dim , kernel_size , act_layer )
215
218
else :
216
219
assert in_dim == out_dim
217
220
self .downsample = nn .Identity ()
@@ -230,7 +233,7 @@ def forward(self, x):
230
233
return x
231
234
232
235
233
- class RepViT (nn .Module ):
236
+ class RepVit (nn .Module ):
234
237
def __init__ (
235
238
self ,
236
239
in_chans = 3 ,
@@ -243,15 +246,16 @@ def __init__(
243
246
num_classes = 1000 ,
244
247
act_layer = nn .GELU ,
245
248
distillation = True ,
249
+ drop_rate = 0. ,
246
250
):
247
- super (RepViT , self ).__init__ ()
251
+ super (RepVit , self ).__init__ ()
248
252
self .grad_checkpointing = False
249
253
self .global_pool = global_pool
250
254
self .embed_dim = embed_dim
251
255
self .num_classes = num_classes
252
256
253
257
in_dim = embed_dim [0 ]
254
- self .stem = RepViTStem (in_chans , in_dim , act_layer )
258
+ self .stem = RepVitStem (in_chans , in_dim , act_layer )
255
259
stride = self .stem .stride
256
260
resolution = tuple ([i // p for i , p in zip (to_2tuple (img_size ), to_2tuple (stride ))])
257
261
@@ -263,7 +267,7 @@ def __init__(
263
267
for i in range (num_stages ):
264
268
downsample = True if i != 0 else False
265
269
stages .append (
266
- RepViTStage (
270
+ RepVitStage (
267
271
in_dim ,
268
272
embed_dim [i ],
269
273
depth [i ],
@@ -281,7 +285,8 @@ def __init__(
281
285
self .stages = nn .Sequential (* stages )
282
286
283
287
self .num_features = embed_dim [- 1 ]
284
- self .head = RepViTClassifier (embed_dim [- 1 ], num_classes , distillation )
288
+ self .head_drop = nn .Dropout (drop_rate )
289
+ self .head = RepVitClassifier (embed_dim [- 1 ], num_classes , distillation )
285
290
286
291
@torch .jit .ignore
287
292
def group_matcher (self , coarse = False ):
@@ -304,9 +309,13 @@ def reset_classifier(self, num_classes, global_pool=None, distillation=False):
304
309
if global_pool is not None :
305
310
self .global_pool = global_pool
306
311
self .head = (
307
- RepViTClassifier (self .embed_dim [- 1 ], num_classes , distillation ) if num_classes > 0 else nn .Identity ()
312
+ RepVitClassifier (self .embed_dim [- 1 ], num_classes , distillation ) if num_classes > 0 else nn .Identity ()
308
313
)
309
314
315
@torch.jit.ignore
def set_distilled_training(self, enable=True):
    """Set the ``distilled_training`` flag on the model's head to ``enable``."""
    classifier = self.head
    classifier.distilled_training = enable
318
+
310
319
def forward_features (self , x ):
311
320
x = self .stem (x )
312
321
if self .grad_checkpointing and not torch .jit .is_scripting ():
@@ -317,8 +326,9 @@ def forward_features(self, x):
317
326
318
327
def forward_head(self, x, pre_logits: bool = False):
    """Pool the feature map, apply head dropout, and optionally classify.

    Args:
        x: Feature tensor; when ``self.global_pool == 'avg'`` it is averaged
            over dims 2 and 3.
        pre_logits: When true, return the pooled (and dropped-out) features
            instead of passing them through ``self.head``.

    Returns:
        The pre-classifier features if ``pre_logits`` is set, otherwise the
        output of ``self.head``.
    """
    if self.global_pool == 'avg':
        x = x.mean((2, 3), keepdim=False)
    x = self.head_drop(x)
    # Honour ``pre_logits``: the parameter is part of the signature and must
    # not be silently ignored by always running the classifier.
    if pre_logits:
        return x
    return self.head(x)
322
332
323
333
def forward (self , x ):
324
334
x = self .forward_features (x )
@@ -373,7 +383,9 @@ def _cfg(url='', **kwargs):
373
383
def _create_repvit(variant, pretrained=False, **kwargs):
    """Build a ``RepVit`` model for ``variant`` via ``build_model_with_cfg``.

    ``out_indices`` is removed from ``kwargs`` (default ``(0, 1, 2, 3)``) and
    forwarded inside ``feature_cfg`` together with ``flatten_sequential=True``;
    the remaining keyword arguments are passed through unchanged.
    """
    feature_cfg = dict(
        flatten_sequential=True,
        out_indices=kwargs.pop('out_indices', (0, 1, 2, 3)),
    )
    return build_model_with_cfg(
        RepVit,
        variant,
        pretrained,
        feature_cfg=feature_cfg,
        **kwargs,
    )
379
391
0 commit comments