Skip to content

Commit c692715

Browse files
committed
Some RepVit tweaks
* Add head dropout to RepVit, as all models have that arg.
* Default training to the non-distilled head output via the `distilled_training` flag (toggled with `set_distilled_training`), so fine-tuning works by default without a distillation script.
* Camel-case naming tweaks to match other models.
1 parent f677190 commit c692715

File tree

1 file changed

+37
-25
lines changed

1 file changed

+37
-25
lines changed

timm/models/repvit.py

+37-25
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
Adapted from official impl at https://github.com/jameslahm/RepViT
1616
"""
1717

18-
__all__ = ['RepViT']
18+
__all__ = ['RepVit']
1919

2020
import torch.nn as nn
2121
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
@@ -81,7 +81,7 @@ def fuse(self):
8181
return m
8282

8383

84-
class RepVGGDW(nn.Module):
84+
class RepVggDw(nn.Module):
8585
def __init__(self, ed, kernel_size):
8686
super().__init__()
8787
self.conv = ConvNorm(ed, ed, kernel_size, 1, (kernel_size - 1) // 2, groups=ed)
@@ -115,7 +115,7 @@ def fuse(self):
115115
return conv
116116

117117

118-
class RepViTMlp(nn.Module):
118+
class RepVitMlp(nn.Module):
119119
def __init__(self, in_dim, hidden_dim, act_layer):
120120
super().__init__()
121121
self.conv1 = ConvNorm(in_dim, hidden_dim, 1, 1, 0)
@@ -130,9 +130,9 @@ class RepViTBlock(nn.Module):
130130
def __init__(self, in_dim, mlp_ratio, kernel_size, use_se, act_layer):
131131
super(RepViTBlock, self).__init__()
132132

133-
self.token_mixer = RepVGGDW(in_dim, kernel_size)
133+
self.token_mixer = RepVggDw(in_dim, kernel_size)
134134
self.se = SqueezeExcite(in_dim, 0.25) if use_se else nn.Identity()
135-
self.channel_mixer = RepViTMlp(in_dim, in_dim * mlp_ratio, act_layer)
135+
self.channel_mixer = RepVitMlp(in_dim, in_dim * mlp_ratio, act_layer)
136136

137137
def forward(self, x):
138138
x = self.token_mixer(x)
@@ -142,7 +142,7 @@ def forward(self, x):
142142
return identity + x
143143

144144

145-
class RepViTStem(nn.Module):
145+
class RepVitStem(nn.Module):
146146
def __init__(self, in_chs, out_chs, act_layer):
147147
super().__init__()
148148
self.conv1 = ConvNorm(in_chs, out_chs // 2, 3, 2, 1)
@@ -154,13 +154,13 @@ def forward(self, x):
154154
return self.conv2(self.act1(self.conv1(x)))
155155

156156

157-
class RepViTDownsample(nn.Module):
157+
class RepVitDownsample(nn.Module):
158158
def __init__(self, in_dim, mlp_ratio, out_dim, kernel_size, act_layer):
159159
super().__init__()
160160
self.pre_block = RepViTBlock(in_dim, mlp_ratio, kernel_size, use_se=False, act_layer=act_layer)
161161
self.spatial_downsample = ConvNorm(in_dim, in_dim, kernel_size, 2, (kernel_size - 1) // 2, groups=in_dim)
162162
self.channel_downsample = ConvNorm(in_dim, out_dim, 1, 1)
163-
self.ffn = RepViTMlp(out_dim, out_dim * mlp_ratio, act_layer)
163+
self.ffn = RepVitMlp(out_dim, out_dim * mlp_ratio, act_layer)
164164

165165
def forward(self, x):
166166
x = self.pre_block(x)
@@ -171,22 +171,25 @@ def forward(self, x):
171171
return x + identity
172172

173173

174-
class RepViTClassifier(nn.Module):
175-
def __init__(self, dim, num_classes, distillation=False):
174+
class RepVitClassifier(nn.Module):
175+
def __init__(self, dim, num_classes, distillation=False, drop=0.):
176176
super().__init__()
177+
self.head_drop = nn.Dropout(drop)
177178
self.head = NormLinear(dim, num_classes) if num_classes > 0 else nn.Identity()
178179
self.distillation = distillation
179-
self.num_classes=num_classes
180+
self.distilled_training = False
181+
self.num_classes = num_classes
180182
if distillation:
181183
self.head_dist = NormLinear(dim, num_classes) if num_classes > 0 else nn.Identity()
182184

183185
def forward(self, x):
186+
x = self.head_drop(x)
184187
if self.distillation:
185188
x1, x2 = self.head(x), self.head_dist(x)
186-
if (not self.training) or torch.jit.is_scripting():
187-
return (x1 + x2) / 2
188-
else:
189+
if self.training and self.distilled_training and not torch.jit.is_scripting():
189190
return x1, x2
191+
else:
192+
return (x1 + x2) / 2
190193
else:
191194
x = self.head(x)
192195
return x
@@ -207,11 +210,11 @@ def fuse(self):
207210
return head
208211

209212

210-
class RepViTStage(nn.Module):
213+
class RepVitStage(nn.Module):
211214
def __init__(self, in_dim, out_dim, depth, mlp_ratio, act_layer, kernel_size=3, downsample=True):
212215
super().__init__()
213216
if downsample:
214-
self.downsample = RepViTDownsample(in_dim, mlp_ratio, out_dim, kernel_size, act_layer)
217+
self.downsample = RepVitDownsample(in_dim, mlp_ratio, out_dim, kernel_size, act_layer)
215218
else:
216219
assert in_dim == out_dim
217220
self.downsample = nn.Identity()
@@ -230,7 +233,7 @@ def forward(self, x):
230233
return x
231234

232235

233-
class RepViT(nn.Module):
236+
class RepVit(nn.Module):
234237
def __init__(
235238
self,
236239
in_chans=3,
@@ -243,15 +246,16 @@ def __init__(
243246
num_classes=1000,
244247
act_layer=nn.GELU,
245248
distillation=True,
249+
drop_rate=0.,
246250
):
247-
super(RepViT, self).__init__()
251+
super(RepVit, self).__init__()
248252
self.grad_checkpointing = False
249253
self.global_pool = global_pool
250254
self.embed_dim = embed_dim
251255
self.num_classes = num_classes
252256

253257
in_dim = embed_dim[0]
254-
self.stem = RepViTStem(in_chans, in_dim, act_layer)
258+
self.stem = RepVitStem(in_chans, in_dim, act_layer)
255259
stride = self.stem.stride
256260
resolution = tuple([i // p for i, p in zip(to_2tuple(img_size), to_2tuple(stride))])
257261

@@ -263,7 +267,7 @@ def __init__(
263267
for i in range(num_stages):
264268
downsample = True if i != 0 else False
265269
stages.append(
266-
RepViTStage(
270+
RepVitStage(
267271
in_dim,
268272
embed_dim[i],
269273
depth[i],
@@ -281,7 +285,8 @@ def __init__(
281285
self.stages = nn.Sequential(*stages)
282286

283287
self.num_features = embed_dim[-1]
284-
self.head = RepViTClassifier(embed_dim[-1], num_classes, distillation)
288+
self.head_drop = nn.Dropout(drop_rate)
289+
self.head = RepVitClassifier(embed_dim[-1], num_classes, distillation)
285290

286291
@torch.jit.ignore
287292
def group_matcher(self, coarse=False):
@@ -304,9 +309,13 @@ def reset_classifier(self, num_classes, global_pool=None, distillation=False):
304309
if global_pool is not None:
305310
self.global_pool = global_pool
306311
self.head = (
307-
RepViTClassifier(self.embed_dim[-1], num_classes, distillation) if num_classes > 0 else nn.Identity()
312+
RepVitClassifier(self.embed_dim[-1], num_classes, distillation) if num_classes > 0 else nn.Identity()
308313
)
309314

315+
@torch.jit.ignore
316+
def set_distilled_training(self, enable=True):
317+
self.head.distilled_training = enable
318+
310319
def forward_features(self, x):
311320
x = self.stem(x)
312321
if self.grad_checkpointing and not torch.jit.is_scripting():
@@ -317,8 +326,9 @@ def forward_features(self, x):
317326

318327
def forward_head(self, x, pre_logits: bool = False):
319328
if self.global_pool == 'avg':
320-
x = nn.functional.adaptive_avg_pool2d(x, 1).flatten(1)
321-
return x if pre_logits else self.head(x)
329+
x = x.mean((2, 3), keepdim=False)
330+
x = self.head_drop(x)
331+
return self.head(x)
322332

323333
def forward(self, x):
324334
x = self.forward_features(x)
@@ -373,7 +383,9 @@ def _cfg(url='', **kwargs):
373383
def _create_repvit(variant, pretrained=False, **kwargs):
374384
out_indices = kwargs.pop('out_indices', (0, 1, 2, 3))
375385
model = build_model_with_cfg(
376-
RepViT, variant, pretrained, feature_cfg=dict(flatten_sequential=True, out_indices=out_indices), **kwargs
386+
RepVit, variant, pretrained,
387+
feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
388+
**kwargs,
377389
)
378390
return model
379391

0 commit comments

Comments
 (0)