Skip to content

Commit b1a6f4a

Browse files
committed
Some missed reset_classifier() type annotations
1 parent 71101eb commit b1a6f4a

16 files changed

+33
-23
lines changed

timm/models/efficientnet.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ def set_grad_checkpointing(self, enable=True):
156156
def get_classifier(self) -> nn.Module:
157157
return self.classifier
158158

159-
def reset_classifier(self, num_classes, global_pool='avg'):
159+
def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
160160
self.num_classes = num_classes
161161
self.global_pool, self.classifier = create_classifier(
162162
self.num_features, self.num_classes, pool_type=global_pool)

timm/models/ghostnet.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,7 @@ def set_grad_checkpointing(self, enable=True):
273273
def get_classifier(self) -> nn.Module:
274274
return self.classifier
275275

276-
def reset_classifier(self, num_classes, global_pool='avg'):
276+
def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
277277
self.num_classes = num_classes
278278
# cannot meaningfully change pooling of efficient head after creation
279279
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)

timm/models/hrnet.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -739,7 +739,7 @@ def set_grad_checkpointing(self, enable=True):
739739
def get_classifier(self) -> nn.Module:
740740
return self.classifier
741741

742-
def reset_classifier(self, num_classes, global_pool='avg'):
742+
def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
743743
self.num_classes = num_classes
744744
self.global_pool, self.classifier = create_classifier(
745745
self.num_features, self.num_classes, pool_type=global_pool)

timm/models/inception_v4.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,7 @@ def set_grad_checkpointing(self, enable=True):
280280
def get_classifier(self) -> nn.Module:
281281
return self.last_linear
282282

283-
def reset_classifier(self, num_classes, global_pool='avg'):
283+
def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
284284
self.num_classes = num_classes
285285
self.global_pool, self.last_linear = create_classifier(
286286
self.num_features, self.num_classes, pool_type=global_pool)

timm/models/metaformer.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,9 @@
2626
# See the License for the specific language governing permissions and
2727
# limitations under the License.
2828

29-
3029
from collections import OrderedDict
3130
from functools import partial
31+
from typing import Optional
3232

3333
import torch
3434
import torch.nn as nn
@@ -548,7 +548,7 @@ def __init__(
548548
# if using MlpHead, dropout is handled by MlpHead
549549
if num_classes > 0:
550550
if self.use_mlp_head:
551-
# FIXME hidden size
551+
# FIXME not actually returning mlp hidden state right now as pre-logits.
552552
final = MlpHead(self.num_features, num_classes, drop_rate=self.drop_rate)
553553
self.head_hidden_size = self.num_features
554554
else:
@@ -583,7 +583,7 @@ def set_grad_checkpointing(self, enable=True):
583583
def get_classifier(self) -> nn.Module:
584584
return self.head.fc
585585

586-
def reset_classifier(self, num_classes=0, global_pool=None):
586+
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
587587
if global_pool is not None:
588588
self.head.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
589589
self.head.flatten = nn.Flatten(1) if global_pool else nn.Identity()

timm/models/nasnet.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -518,7 +518,7 @@ def set_grad_checkpointing(self, enable=True):
518518
def get_classifier(self) -> nn.Module:
519519
return self.last_linear
520520

521-
def reset_classifier(self, num_classes, global_pool='avg'):
521+
def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
522522
self.num_classes = num_classes
523523
self.global_pool, self.last_linear = create_classifier(
524524
self.num_features, self.num_classes, pool_type=global_pool)

timm/models/pnasnet.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,7 @@ def set_grad_checkpointing(self, enable=True):
307307
def get_classifier(self) -> nn.Module:
308308
return self.last_linear
309309

310-
def reset_classifier(self, num_classes, global_pool='avg'):
310+
def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
311311
self.num_classes = num_classes
312312
self.global_pool, self.last_linear = create_classifier(
313313
self.num_features, self.num_classes, pool_type=global_pool)

timm/models/regnet.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -514,7 +514,7 @@ def set_grad_checkpointing(self, enable=True):
514514
def get_classifier(self) -> nn.Module:
515515
return self.head.fc
516516

517-
def reset_classifier(self, num_classes, global_pool='avg'):
517+
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
518518
self.head.reset(num_classes, pool_type=global_pool)
519519

520520
def forward_intermediates(

timm/models/rexnet.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
from functools import partial
1414
from math import ceil
15+
from typing import Optional
1516

1617
import torch
1718
import torch.nn as nn
@@ -229,7 +230,7 @@ def set_grad_checkpointing(self, enable=True):
229230
def get_classifier(self) -> nn.Module:
230231
return self.head.fc
231232

232-
def reset_classifier(self, num_classes, global_pool='avg'):
233+
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
233234
self.num_classes = num_classes
234235
self.head.reset(num_classes, global_pool)
235236

timm/models/selecsls.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ def set_grad_checkpointing(self, enable=True):
161161
def get_classifier(self) -> nn.Module:
162162
return self.fc
163163

164-
def reset_classifier(self, num_classes, global_pool='avg'):
164+
def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
165165
self.num_classes = num_classes
166166
self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
167167

timm/models/senet.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -337,7 +337,7 @@ def set_grad_checkpointing(self, enable=True):
337337
def get_classifier(self) -> nn.Module:
338338
return self.last_linear
339339

340-
def reset_classifier(self, num_classes, global_pool='avg'):
340+
def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
341341
self.num_classes = num_classes
342342
self.global_pool, self.last_linear = create_classifier(
343343
self.num_features, self.num_classes, pool_type=global_pool)

timm/models/vision_transformer_relpos.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -381,7 +381,7 @@ def set_grad_checkpointing(self, enable=True):
381381
def get_classifier(self) -> nn.Module:
382382
return self.head
383383

384-
def reset_classifier(self, num_classes: int, global_pool=None):
384+
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
385385
self.num_classes = num_classes
386386
if global_pool is not None:
387387
assert global_pool in ('', 'avg', 'token')

timm/models/vision_transformer_sam.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -536,7 +536,7 @@ def set_grad_checkpointing(self, enable=True):
536536
def get_classifier(self) -> nn.Module:
537537
return self.head
538538

539-
def reset_classifier(self, num_classes=0, global_pool=None):
539+
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
540540
self.head.reset(num_classes, global_pool)
541541

542542
def forward_intermediates(

timm/models/vovnet.py

+15-6
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
Hacked together by / Copyright 2020 Ross Wightman
1212
"""
1313

14-
from typing import List
14+
from typing import List, Optional
1515

1616
import torch
1717
import torch.nn as nn
@@ -134,9 +134,17 @@ def __init__(
134134
else:
135135
drop_path = None
136136
blocks += [OsaBlock(
137-
in_chs, mid_chs, out_chs, layer_per_block, residual=residual and i > 0, depthwise=depthwise,
138-
attn=attn if last_block else '', norm_layer=norm_layer, act_layer=act_layer, drop_path=drop_path)
139-
]
137+
in_chs,
138+
mid_chs,
139+
out_chs,
140+
layer_per_block,
141+
residual=residual and i > 0,
142+
depthwise=depthwise,
143+
attn=attn if last_block else '',
144+
norm_layer=norm_layer,
145+
act_layer=act_layer,
146+
drop_path=drop_path
147+
)]
140148
in_chs = out_chs
141149
self.blocks = nn.Sequential(*blocks)
142150

@@ -252,8 +260,9 @@ def set_grad_checkpointing(self, enable=True):
252260
def get_classifier(self) -> nn.Module:
253261
return self.head.fc
254262

255-
def reset_classifier(self, num_classes, global_pool='avg'):
256-
self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
263+
def reset_classifier(self, num_classes, global_pool: Optional[str] = None):
264+
self.num_classes = num_classes
265+
self.head.reset(num_classes, global_pool)
257266

258267
def forward_features(self, x):
259268
x = self.stem(x)

timm/models/xception.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ def set_grad_checkpointing(self, enable=True):
174174
def get_classifier(self) -> nn.Module:
175175
return self.fc
176176

177-
def reset_classifier(self, num_classes, global_pool='avg'):
177+
def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
178178
self.num_classes = num_classes
179179
self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
180180

timm/models/xception_aligned.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -274,7 +274,7 @@ def set_grad_checkpointing(self, enable=True):
274274
def get_classifier(self) -> nn.Module:
275275
return self.head.fc
276276

277-
def reset_classifier(self, num_classes, global_pool='avg'):
277+
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
278278
self.head.reset(num_classes, pool_type=global_pool)
279279

280280
def forward_features(self, x):

0 commit comments

Comments (0)