@@ -845,7 +845,9 @@ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str =
     """
     import numpy as np

-    def _n2p(w, t=True):
+    def _n2p(w, t=True, idx=None):
+        if idx is not None:
+            w = w[idx]
         if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
             w = w.flatten()
         if t:
@@ -955,21 +957,28 @@ def _n2p(w, t=True):

     mha_sub, b_sub, ln1_sub = (0, 0, 1) if big_vision else (1, 3, 2)
     for i, block in enumerate(model.blocks.children()):
-        block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
+        if f'{prefix}Transformer/encoderblock/LayerNorm_0/scale' in w:
+            block_prefix = f'{prefix}Transformer/encoderblock/'
+            idx = i
+        else:
+            block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
+            idx = None
         mha_prefix = block_prefix + f'MultiHeadDotProductAttention_{mha_sub}/'
-        block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
-        block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
+        block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'], idx=idx))
+        block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'], idx=idx))
         block.attn.qkv.weight.copy_(torch.cat([
-            _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
+            _n2p(w[f'{mha_prefix}{n}/kernel'], t=False, idx=idx).flatten(1).T for n in ('query', 'key', 'value')]))
         block.attn.qkv.bias.copy_(torch.cat([
-            _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
-        block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
-        block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
-        block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/scale']))
-        block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/bias']))
+            _n2p(w[f'{mha_prefix}{n}/bias'], t=False, idx=idx).reshape(-1) for n in ('query', 'key', 'value')]))
+        block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel'], idx=idx).flatten(1))
+        block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'], idx=idx))
+        block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/scale'], idx=idx))
+        block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/bias'], idx=idx))
         for r in range(2):
-            getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/kernel']))
-            getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/bias']))
+            getattr(block.mlp, f'fc{r + 1}').weight.copy_(
+                _n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/kernel'], idx=idx))
+            getattr(block.mlp, f'fc{r + 1}').bias.copy_(
+                _n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/bias'], idx=idx))


 def _convert_openai_clip(
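The `_n2p`/`_load_weights` changes above add an optional `idx` so the loader can also read checkpoints that keep all encoder blocks under a single `Transformer/encoderblock/` key, with per-layer parameters stacked along a leading depth axis, while the numbered `encoderblock_{i}/` layout still takes the old path. Below is a minimal standalone sketch of that lookup, using a hypothetical stacked array and a simplified `n2p` helper rather than the real `_n2p`:

```python
import numpy as np
import torch

# Hypothetical stacked checkpoint dict: the LayerNorm scale for every block is
# stored once, stacked along a leading depth axis, instead of one
# 'Transformer/encoderblock_{i}/...' entry per block.
depth, dim = 27, 1152
w = {'Transformer/encoderblock/LayerNorm_0/scale': np.ones((depth, dim), dtype=np.float32)}

def n2p(arr, idx=None):
    # Simplified stand-in for _n2p: optionally select one layer, then convert to a tensor.
    if idx is not None:
        arr = arr[idx]
    return torch.from_numpy(arr)

prefix = ''
i = 5  # block index currently being loaded
if f'{prefix}Transformer/encoderblock/LayerNorm_0/scale' in w:
    # Stacked layout detected: shared prefix, layer selected inside n2p via idx.
    block_prefix, idx = f'{prefix}Transformer/encoderblock/', i
else:
    # Classic layout: numbered per-block prefix, no extra indexing.
    block_prefix, idx = f'{prefix}Transformer/encoderblock_{i}/', None

scale = n2p(w[f'{block_prefix}LayerNorm_0/scale'], idx=idx)
print(scale.shape)  # torch.Size([1152])
```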
@@ -1769,6 +1778,44 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
         input_size=(3, 384, 384),
         num_classes=0),

+    'vit_so400m_patch14_siglip_gap_224.webli': _cfg(
+        hf_hub_id='timm/ViT-SO400M-14-SigLIP',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        num_classes=0),
+    'vit_so400m_patch14_siglip_gap_224.pali_mix': _cfg(
+        hf_hub_id='google/paligemma-3b-mix-224-jax',
+        hf_hub_filename='paligemma-3b-mix-224.npz',
+        custom_load='hf',
+        num_classes=0),
+    'vit_so400m_patch14_siglip_gap_224.pali_pt': _cfg(
+        hf_hub_id='google/paligemma-3b-pt-224-jax',
+        hf_hub_filename='paligemma-3b-pt-224.npz',
+        custom_load='hf',
+        num_classes=0),
+    'vit_so400m_patch14_siglip_gap_384.webli': _cfg(
+        hf_hub_id='timm/ViT-SO400M-14-SigLIP-384',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        input_size=(3, 384, 384), crop_pct=1.0,
+        num_classes=0),
+    'vit_so400m_patch14_siglip_gap_448.pali_mix': _cfg(
+        hf_hub_id='google/paligemma-3b-mix-448-jax',
+        hf_hub_filename='paligemma-3b-mix-448.npz',
+        custom_load='hf',
+        input_size=(3, 448, 448), crop_pct=1.0,
+        num_classes=0),
+    'vit_so400m_patch14_siglip_gap_448.pali_pt': _cfg(
+        hf_hub_id='google/paligemma-3b-pt-448-jax',
+        hf_hub_filename='paligemma-3b-pt-448.npz',
+        custom_load='hf',
+        input_size=(3, 448, 448), crop_pct=1.0,
+        num_classes=0),
+    'vit_so400m_patch14_siglip_gap_896.pali_pt': _cfg(
+        hf_hub_id='google/paligemma-3b-pt-896-jax',
+        hf_hub_filename='paligemma-3b-pt-896.npz',
+        custom_load='hf',
+        input_size=(3, 896, 896), crop_pct=1.0,
+        num_classes=0),
+
     'vit_xsmall_patch16_clip_224.tinyclip_yfcc15m': _cfg(
         hf_hub_id='timm/',
         hf_hub_filename='open_clip_pytorch_model.bin',
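For the new `default_cfgs` entries above: the suffix after the model name is a pretrained tag (`.webli`, `.pali_mix`, `.pali_pt`) that selects the Hub repo and weight file, and `custom_load='hf'` appears to route the PaliGemma `.npz` files through the npz loading path patched earlier rather than a regular state-dict load. A minimal usage sketch, assuming a timm build that includes these entries (the `google/paligemma-*` repos may additionally require accepting the license on the Hugging Face Hub):

```python
import timm

# Enumerate the tags registered for the new GAP-pooled SO400M/14 SigLIP models.
print(timm.list_pretrained('vit_so400m_patch14_siglip_gap_*'))

# Instantiate by name.tag; pretrained=True would download from the hf_hub_id /
# hf_hub_filename pair given in the matching cfg entry above.
model = timm.create_model('vit_so400m_patch14_siglip_gap_448.pali_pt', pretrained=False)
```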
@@ -2756,15 +2803,48 @@ def vit_so400m_patch14_siglip_384(pretrained: bool = False, **kwargs) -> VisionT
     return model


-# @register_model
-# def vit_medium_patch16_reg4_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
-#     model_args = dict(
-#         patch_size=16, embed_dim=512, depth=12, num_heads=8, class_token=True,
-#         no_embed_class=True, reg_tokens=4,
-#     )
-#     model = _create_vision_transformer(
-#         'vit_medium_patch16_reg4_256', pretrained=pretrained, **dict(model_args, **kwargs))
-#     return model
+@register_model
+def vit_so400m_patch14_siglip_gap_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    model_args = dict(
+        patch_size=14, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362,
+        class_token=False, global_pool='avg', fc_norm=False,
+    )
+    model = _create_vision_transformer(
+        'vit_so400m_patch14_siglip_gap_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_so400m_patch14_siglip_gap_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    model_args = dict(
+        patch_size=14, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362,
+        class_token=False, global_pool='avg', fc_norm=False,
+    )
+    model = _create_vision_transformer(
+        'vit_so400m_patch14_siglip_gap_384', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_so400m_patch14_siglip_gap_448(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    model_args = dict(
+        patch_size=14, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362,
+        class_token=False, global_pool='avg', fc_norm=False,
+    )
+    model = _create_vision_transformer(
+        'vit_so400m_patch14_siglip_gap_448', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_so400m_patch14_siglip_gap_896(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    model_args = dict(
+        patch_size=14, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362,
+        class_token=False, global_pool='avg', fc_norm=False,
+    )
+    model = _create_vision_transformer(
+        'vit_so400m_patch14_siglip_gap_896', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model


 @register_model
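All four `*_siglip_gap_*` entrypoints above build the same SoViT-400M/14 trunk (embed_dim 1152, depth 27, mlp_ratio 3.7362) with global average pooling and no class token; the resolution in the name mainly selects which pretrained cfg (224/384/448/896) is associated with it. A quick sketch of using one as a feature extractor, assuming a timm build that includes these entrypoints:

```python
import torch
import timm

# num_classes=0 removes the classifier head, so the forward pass returns the
# globally average-pooled features of width 1152.
model = timm.create_model('vit_so400m_patch14_siglip_gap_224', pretrained=False, num_classes=0)
model.eval()

with torch.no_grad():
    feats = model(torch.randn(1, 3, 224, 224))

print(feats.shape)  # expected: torch.Size([1, 1152])
```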