Commit 71101eb

Refactor vit pooling to add more reduction options, separately callable
1 parent 9567cf6

File tree: 1 file changed, +38 -12 lines

timm/models/vision_transformer.py

@@ -386,6 +386,31 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self._forward(x)
 
 
+def global_pool_nlc(
+        x: torch.Tensor,
+        pool_type: str = 'token',
+        num_prefix_tokens: int = 1,
+        reduce_include_prefix: bool = False,
+):
+    if not pool_type:
+        return x
+
+    if pool_type == 'token':
+        x = x[:, 0]  # class token
+    else:
+        x = x if reduce_include_prefix else x[:, num_prefix_tokens:]
+        if pool_type == 'avg':
+            x = x.mean(dim=1)
+        elif pool_type == 'avgmax':
+            x = 0.5 * (x.amax(dim=1) + x.mean(dim=1))
+        elif pool_type == 'max':
+            x = x.amax(dim=1)
+        else:
+            assert not pool_type, f'Unknown pool type {pool_type}'
+
+    return x
+
+
 class VisionTransformer(nn.Module):
     """ Vision Transformer
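
For reference (not part of the diff): a minimal sketch of how the new global_pool_nlc helper behaves on a dummy (batch, tokens, channels) sequence, assuming a timm checkout at or after this commit is importable.

import torch
from timm.models.vision_transformer import global_pool_nlc  # added by this commit

# Dummy ViT output: batch 2, 1 class token + 196 patch tokens, width 768.
x = torch.randn(2, 197, 768)

tok = global_pool_nlc(x, pool_type='token')   # class token -> (2, 768)
avg = global_pool_nlc(x, pool_type='avg')     # mean over patch tokens
mx  = global_pool_nlc(x, pool_type='max')     # max over patch tokens
am  = global_pool_nlc(x, pool_type='avgmax')  # 0.5 * (max + mean)
assert all(t.shape == (2, 768) for t in (tok, avg, mx, am))

# An empty pool_type is a no-op: the full (2, 197, 768) sequence comes back.
assert global_pool_nlc(x, pool_type='').shape == (2, 197, 768)

# By default the prefix (class) tokens are excluded from the reduction;
# reduce_include_prefix=True folds them into the mean/max as well.
avg_all = global_pool_nlc(x, pool_type='avg', reduce_include_prefix=True)
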
@@ -400,7 +425,7 @@ def __init__(
             patch_size: Union[int, Tuple[int, int]] = 16,
             in_chans: int = 3,
             num_classes: int = 1000,
-            global_pool: Literal['', 'avg', 'max', 'token', 'map'] = 'token',
+            global_pool: Literal['', 'avg', 'avgmax', 'max', 'token', 'map'] = 'token',
             embed_dim: int = 768,
             depth: int = 12,
             num_heads: int = 12,
@@ -459,10 +484,10 @@ def __init__(
             block_fn: Transformer block layer.
         """
         super().__init__()
-        assert global_pool in ('', 'avg', 'max', 'token', 'map')
+        assert global_pool in ('', 'avg', 'avgmax', 'max', 'token', 'map')
         assert class_token or global_pool != 'token'
         assert pos_embed in ('', 'none', 'learn')
-        use_fc_norm = global_pool in ['avg', 'max'] if fc_norm is None else fc_norm
+        use_fc_norm = global_pool in ('avg', 'avgmax', 'max') if fc_norm is None else fc_norm
         norm_layer = get_norm_layer(norm_layer) or partial(nn.LayerNorm, eps=1e-6)
         act_layer = get_act_layer(act_layer) or nn.GELU
 
@@ -596,10 +621,10 @@ def set_grad_checkpointing(self, enable: bool = True) -> None:
     def get_classifier(self) -> nn.Module:
         return self.head
 
-    def reset_classifier(self, num_classes: int, global_pool = None) -> None:
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
         self.num_classes = num_classes
         if global_pool is not None:
-            assert global_pool in ('', 'avg', 'token', 'map')
+            assert global_pool in ('', 'avg', 'avgmax', 'max', 'token', 'map')
             if global_pool == 'map' and self.attn_pool is None:
                 assert False, "Cannot currently add attention pooling in reset_classifier()."
             elif global_pool != 'map ' and self.attn_pool is not None:
@@ -756,15 +781,16 @@ def forward_features(self, x: torch.Tensor) -> torch.Tensor:
         x = self.norm(x)
         return x
 
-    def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
+    def pool(self, x: torch.Tensor, pool_type: Optional[str] = None) -> torch.Tensor:
         if self.attn_pool is not None:
             x = self.attn_pool(x)
-        elif self.global_pool == 'avg':
-            x = x[:, self.num_prefix_tokens:].mean(dim=1)
-        elif self.global_pool == 'max':
-            x, _ = torch.max(x[:, self.num_prefix_tokens:], dim=1)
-        elif self.global_pool:
-            x = x[:, 0]  # class token
+            return x
+        pool_type = self.global_pool if pool_type is None else pool_type
+        x = global_pool_nlc(x, pool_type=pool_type, num_prefix_tokens=self.num_prefix_tokens)
+        return x
+
+    def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
+        x = self.pool(x)
         x = self.fc_norm(x)
         x = self.head_drop(x)
         return x if pre_logits else self.head(x)
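
For context (also not part of the diff): a usage sketch of the two things the refactor enables, requesting the new 'avgmax' reduction at construction time, and calling the pooling step separately from the classifier head. The model name and shapes are illustrative, assuming timm at or after this commit.

import torch
import timm

# 'avgmax' is now an accepted global_pool option (weights are random here;
# pretrained=False keeps the sketch offline).
model = timm.create_model('vit_tiny_patch16_224', pretrained=False, global_pool='avgmax')
x = torch.randn(1, 3, 224, 224)
logits = model(x)  # tokens reduced with 0.5 * (max + mean) before the head

# Pooling is separately callable: take unpooled tokens from forward_features(),
# then reduce with any supported pool_type, overriding the model default per call.
tokens = model.forward_features(x)            # (1, 197, 192) for vit_tiny
pooled = model.pool(tokens, pool_type='max')  # (1, 192)
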

0 commit comments
