@@ -386,6 +386,31 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self._forward(x)
 
 
+def global_pool_nlc(
+        x: torch.Tensor,
+        pool_type: str = 'token',
+        num_prefix_tokens: int = 1,
+        reduce_include_prefix: bool = False,
+):
+    if not pool_type:
+        return x
+
+    if pool_type == 'token':
+        x = x[:, 0]  # class token
+    else:
+        x = x if reduce_include_prefix else x[:, num_prefix_tokens:]
+        if pool_type == 'avg':
+            x = x.mean(dim=1)
+        elif pool_type == 'avgmax':
+            x = 0.5 * (x.amax(dim=1) + x.mean(dim=1))
+        elif pool_type == 'max':
+            x = x.amax(dim=1)
+        else:
+            assert not pool_type, f'Unknown pool type {pool_type}'
+
+    return x
+
+
 class VisionTransformer(nn.Module):
     """ Vision Transformer
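For reference, a quick sketch of how the new free function behaves on an NLC (batch, length, channels) token tensor. This is my own illustration, not part of the commit; it assumes `global_pool_nlc` as defined above is in scope and uses the default num_prefix_tokens=1 (one class token, as in a 224x224 / patch16 ViT):

import torch

# Dummy token sequence: batch of 2, 1 class token + 196 patch tokens, dim 768.
tokens = torch.randn(2, 197, 768)

cls_feat = global_pool_nlc(tokens, pool_type='token')   # (2, 768): class token only
avg_feat = global_pool_nlc(tokens, pool_type='avg')     # (2, 768): mean over patch tokens
am_feat = global_pool_nlc(tokens, pool_type='avgmax')   # (2, 768): 0.5 * (max + mean)
raw = global_pool_nlc(tokens, pool_type='')             # (2, 197, 768): unpooled passthrough

# 'avgmax' averages the per-channel max and mean of the non-prefix tokens.
assert torch.allclose(am_feat, 0.5 * (tokens[:, 1:].amax(dim=1) + tokens[:, 1:].mean(dim=1)))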
@@ -400,7 +425,7 @@ def __init__(
             patch_size: Union[int, Tuple[int, int]] = 16,
             in_chans: int = 3,
             num_classes: int = 1000,
-            global_pool: Literal['', 'avg', 'max', 'token', 'map'] = 'token',
+            global_pool: Literal['', 'avg', 'avgmax', 'max', 'token', 'map'] = 'token',
             embed_dim: int = 768,
             depth: int = 12,
             num_heads: int = 12,
@@ -459,10 +484,10 @@ def __init__(
             block_fn: Transformer block layer.
         """
         super().__init__()
-        assert global_pool in ('', 'avg', 'max', 'token', 'map')
+        assert global_pool in ('', 'avg', 'avgmax', 'max', 'token', 'map')
         assert class_token or global_pool != 'token'
         assert pos_embed in ('', 'none', 'learn')
-        use_fc_norm = global_pool in ['avg', 'max'] if fc_norm is None else fc_norm
+        use_fc_norm = global_pool in ('avg', 'avgmax', 'max') if fc_norm is None else fc_norm
         norm_layer = get_norm_layer(norm_layer) or partial(nn.LayerNorm, eps=1e-6)
         act_layer = get_act_layer(act_layer) or nn.GELU
@@ -596,10 +621,10 @@ def set_grad_checkpointing(self, enable: bool = True) -> None:
     def get_classifier(self) -> nn.Module:
         return self.head
 
-    def reset_classifier(self, num_classes: int, global_pool=None) -> None:
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
         self.num_classes = num_classes
         if global_pool is not None:
-            assert global_pool in ('', 'avg', 'token', 'map')
+            assert global_pool in ('', 'avg', 'avgmax', 'max', 'token', 'map')
             if global_pool == 'map' and self.attn_pool is None:
                 assert False, "Cannot currently add attention pooling in reset_classifier()."
             elif global_pool != 'map' and self.attn_pool is not None:
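Per the widened assert above, reset_classifier() can now switch an existing model to the added 'avgmax'/'max' modes. A minimal sketch of that, assuming a timm build containing this change; the model name is just an example:

import timm
import torch

model = timm.create_model('vit_small_patch16_224', pretrained=False)
model.reset_classifier(num_classes=10, global_pool='avgmax')

logits = model(torch.randn(1, 3, 224, 224))  # (1, 10); features pooled as 0.5 * (max + mean)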
@@ -756,15 +781,16 @@ def forward_features(self, x: torch.Tensor) -> torch.Tensor:
         x = self.norm(x)
         return x
 
-    def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
+    def pool(self, x: torch.Tensor, pool_type: Optional[str] = None) -> torch.Tensor:
         if self.attn_pool is not None:
             x = self.attn_pool(x)
-        elif self.global_pool == 'avg':
-            x = x[:, self.num_prefix_tokens:].mean(dim=1)
-        elif self.global_pool == 'max':
-            x, _ = torch.max(x[:, self.num_prefix_tokens:], dim=1)
-        elif self.global_pool:
-            x = x[:, 0]  # class token
+            return x
+        pool_type = self.global_pool if pool_type is None else pool_type
+        x = global_pool_nlc(x, pool_type=pool_type, num_prefix_tokens=self.num_prefix_tokens)
+        return x
+
+    def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
+        x = self.pool(x)
         x = self.fc_norm(x)
         x = self.head_drop(x)
         return x if pre_logits else self.head(x)
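A side effect of splitting pool() out of forward_head() is that the configured pooling can be overridden per call. A hedged sketch under the same assumptions as above (example model name, timm build with this change):

import timm
import torch

model = timm.create_model('vit_small_patch16_224', pretrained=False)
feats = model.forward_features(torch.randn(1, 3, 224, 224))  # (1, N, C) token sequence

# pool_type=None falls back to model.global_pool ('token' here); passing a value
# overrides it, except when an attention-pool head (attn_pool) is present.
avg_feats = model.pool(feats, pool_type='avg')  # (1, C): mean over patch tokens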