
Commit 6653747

Merge pull request #2178 from huggingface/pali_siglip
Support loading of PaliGemma weights into GAP variants of SigLIP ViT.
2 parents: 04462f5 + 7b3b11b

File tree (3 files changed, +116 -23 lines):
  timm/models/_builder.py
  timm/models/_hub.py
  timm/models/vision_transformer.py

timm/models/_builder.py  (+8, -2)

@@ -10,7 +10,8 @@
 from timm.models._features import FeatureListNet, FeatureDictNet, FeatureHookNet, FeatureGetterNet
 from timm.models._features_fx import FeatureGraphNet
 from timm.models._helpers import load_state_dict
-from timm.models._hub import has_hf_hub, download_cached_file, check_cached_file, load_state_dict_from_hf
+from timm.models._hub import has_hf_hub, download_cached_file, check_cached_file, load_state_dict_from_hf,\
+    load_custom_from_hf
 from timm.models._manipulate import adapt_input_conv
 from timm.models._pretrained import PretrainedCfg
 from timm.models._prune import adapt_model_from_file
@@ -185,7 +186,12 @@ def load_pretrained(
     elif load_from == 'hf-hub':
         _logger.info(f'Loading pretrained weights from Hugging Face hub ({pretrained_loc})')
         if isinstance(pretrained_loc, (list, tuple)):
-            state_dict = load_state_dict_from_hf(*pretrained_loc)
+            custom_load = pretrained_cfg.get('custom_load', False)
+            if isinstance(custom_load, str) and custom_load == 'hf':
+                load_custom_from_hf(*pretrained_loc, model)
+                return
+            else:
+                state_dict = load_state_dict_from_hf(*pretrained_loc)
         else:
             state_dict = load_state_dict_from_hf(pretrained_loc)
     else:
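
Note (not part of the diff): the new branch in load_pretrained() only fires when the hub location is a (repo_id, filename) pair and the pretrained cfg carries custom_load='hf'; every other hf-hub case still takes the existing state_dict path. A minimal sketch of the dispatch, assuming timm's usual cfg resolution turns hf_hub_id + hf_hub_filename into that pair; cfg values are copied from the new entries in vision_transformer.py, the rest is illustrative:

# Hedged sketch of the dispatch added above, not library code.
pretrained_cfg = {
    'hf_hub_id': 'google/paligemma-3b-pt-224-jax',
    'hf_hub_filename': 'paligemma-3b-pt-224.npz',
    'custom_load': 'hf',
}
pretrained_loc = (pretrained_cfg['hf_hub_id'], pretrained_cfg['hf_hub_filename'])

custom_load = pretrained_cfg.get('custom_load', False)
if isinstance(pretrained_loc, (list, tuple)) and isinstance(custom_load, str) and custom_load == 'hf':
    route = 'load_custom_from_hf'      # model parses the downloaded file itself
else:
    route = 'load_state_dict_from_hf'  # plain PyTorch state_dict path
print(route)  # -> load_custom_from_hf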

timm/models/_hub.py  (+7)

@@ -190,6 +190,13 @@ def load_state_dict_from_hf(model_id: str, filename: str = HF_WEIGHTS_NAME):
     return torch.load(cached_file, map_location='cpu')
 
 
+def load_custom_from_hf(model_id: str, filename: str, model: torch.nn.Module):
+    assert has_hf_hub(True)
+    hf_model_id, hf_revision = hf_split(model_id)
+    cached_file = hf_hub_download(hf_model_id, filename=filename, revision=hf_revision)
+    return model.load_pretrained(cached_file)
+
+
 def save_config_for_hf(
     model,
     config_path: str,
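
A hedged usage sketch (not part of the diff), assuming a timm build that includes this PR and access to the PaliGemma repos on the Hub (they may require accepting the model license first). Model and file names come from the pretrained cfgs added in vision_transformer.py below:

import timm
from timm.models._hub import load_custom_from_hf

model = timm.create_model('vit_so400m_patch14_siglip_gap_224', pretrained=False)
# Downloads the .npz from the Hub and hands the local path to
# model.load_pretrained(), which for VisionTransformer parses the
# big_vision/JAX checkpoint via _load_weights().
load_custom_from_hf('google/paligemma-3b-mix-224-jax', 'paligemma-3b-mix-224.npz', model)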

timm/models/vision_transformer.py  (+101, -21)

@@ -845,7 +845,9 @@ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str =
     """
     import numpy as np
 
-    def _n2p(w, t=True):
+    def _n2p(w, t=True, idx=None):
+        if idx is not None:
+            w = w[idx]
         if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
             w = w.flatten()
         if t:
@@ -955,21 +957,28 @@ def _n2p(w, t=True):
 
     mha_sub, b_sub, ln1_sub = (0, 0, 1) if big_vision else (1, 3, 2)
     for i, block in enumerate(model.blocks.children()):
-        block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
+        if f'{prefix}Transformer/encoderblock/LayerNorm_0/scale' in w:
+            block_prefix = f'{prefix}Transformer/encoderblock/'
+            idx = i
+        else:
+            block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
+            idx = None
         mha_prefix = block_prefix + f'MultiHeadDotProductAttention_{mha_sub}/'
-        block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
-        block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
+        block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'], idx=idx))
+        block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'], idx=idx))
         block.attn.qkv.weight.copy_(torch.cat([
-            _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
+            _n2p(w[f'{mha_prefix}{n}/kernel'], t=False, idx=idx).flatten(1).T for n in ('query', 'key', 'value')]))
         block.attn.qkv.bias.copy_(torch.cat([
-            _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
-        block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
-        block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
-        block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/scale']))
-        block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/bias']))
+            _n2p(w[f'{mha_prefix}{n}/bias'], t=False, idx=idx).reshape(-1) for n in ('query', 'key', 'value')]))
+        block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel'], idx=idx).flatten(1))
+        block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'], idx=idx))
+        block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/scale'], idx=idx))
+        block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/bias'], idx=idx))
         for r in range(2):
-            getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/kernel']))
-            getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/bias']))
+            getattr(block.mlp, f'fc{r + 1}').weight.copy_(
+                _n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/kernel'], idx=idx))
+            getattr(block.mlp, f'fc{r + 1}').bias.copy_(
+                _n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/bias'], idx=idx))
 
 
 def _convert_openai_clip(
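
Context for the change above (my reading of the diff, not text from the PR): the PaliGemma big_vision checkpoints store all transformer blocks stacked under a single Transformer/encoderblock/... entry with the block index as the leading axis, rather than one encoderblock_{i}/... entry per block, so the new idx argument lets _n2p select block i before the usual transpose/reshape. A tiny illustration with hypothetical array shapes:

import numpy as np

depth, dim = 27, 1152  # SO400M ViT: 27 blocks, width 1152 (per the model args added below)

# Stacked (PaliGemma-style) layout: one array per parameter, leading block axis.
w_stacked = {'Transformer/encoderblock/LayerNorm_0/scale': np.ones((depth, dim))}
# Per-block layout handled by the existing code path.
w_split = {f'Transformer/encoderblock_{i}/LayerNorm_0/scale': np.ones(dim) for i in range(depth)}

i = 3
picked = w_stacked['Transformer/encoderblock/LayerNorm_0/scale'][i]   # what idx=i selects
legacy = w_split[f'Transformer/encoderblock_{i}/LayerNorm_0/scale']   # what the old path reads
assert picked.shape == legacy.shape == (dim,)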
@@ -1769,6 +1778,44 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
         input_size=(3, 384, 384),
         num_classes=0),
 
+    'vit_so400m_patch14_siglip_gap_224.webli': _cfg(
+        hf_hub_id='timm/ViT-SO400M-14-SigLIP',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        num_classes=0),
+    'vit_so400m_patch14_siglip_gap_224.pali_mix': _cfg(
+        hf_hub_id='google/paligemma-3b-mix-224-jax',
+        hf_hub_filename='paligemma-3b-mix-224.npz',
+        custom_load='hf',
+        num_classes=0),
+    'vit_so400m_patch14_siglip_gap_224.pali_pt': _cfg(
+        hf_hub_id='google/paligemma-3b-pt-224-jax',
+        hf_hub_filename='paligemma-3b-pt-224.npz',
+        custom_load='hf',
+        num_classes=0),
+    'vit_so400m_patch14_siglip_gap_384.webli': _cfg(
+        hf_hub_id='timm/ViT-SO400M-14-SigLIP-384',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        input_size=(3, 384, 384), crop_pct=1.0,
+        num_classes=0),
+    'vit_so400m_patch14_siglip_gap_448.pali_mix': _cfg(
+        hf_hub_id='google/paligemma-3b-mix-448-jax',
+        hf_hub_filename='paligemma-3b-mix-448.npz',
+        custom_load='hf',
+        input_size=(3, 448, 448), crop_pct=1.0,
+        num_classes=0),
+    'vit_so400m_patch14_siglip_gap_448.pali_pt': _cfg(
+        hf_hub_id='google/paligemma-3b-pt-448-jax',
+        hf_hub_filename='paligemma-3b-pt-448.npz',
+        custom_load='hf',
+        input_size=(3, 448, 448), crop_pct=1.0,
+        num_classes=0),
+    'vit_so400m_patch14_siglip_gap_896.pali_pt': _cfg(
+        hf_hub_id='google/paligemma-3b-pt-896-jax',
+        hf_hub_filename='paligemma-3b-pt-896.npz',
+        custom_load='hf',
+        input_size=(3, 896, 896), crop_pct=1.0,
+        num_classes=0),
+
     'vit_xsmall_patch16_clip_224.tinyclip_yfcc15m': _cfg(
         hf_hub_id='timm/',
         hf_hub_filename='open_clip_pytorch_model.bin',
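
A hedged usage sketch for the new pretrained tags (not part of the diff): with num_classes=0 and global_pool='avg' these entry points return pooled image features. The google/paligemma-* repos may require accepting the license on the Hub before the .npz files can be downloaded.

import torch
import timm

model = timm.create_model('vit_so400m_patch14_siglip_gap_224.pali_mix', pretrained=True)
model = model.eval()

with torch.no_grad():
    feats = model(torch.randn(1, 3, 224, 224))
print(feats.shape)  # expected: torch.Size([1, 1152]) -- GAP-pooled SO400M features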
@@ -2756,15 +2803,48 @@ def vit_so400m_patch14_siglip_384(pretrained: bool = False, **kwargs) -> VisionT
     return model
 
 
-# @register_model
-# def vit_medium_patch16_reg4_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
-#     model_args = dict(
-#         patch_size=16, embed_dim=512, depth=12, num_heads=8, class_token=True,
-#         no_embed_class=True, reg_tokens=4,
-#     )
-#     model = _create_vision_transformer(
-#         'vit_medium_patch16_reg4_256', pretrained=pretrained, **dict(model_args, **kwargs))
-#     return model
+@register_model
+def vit_so400m_patch14_siglip_gap_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    model_args = dict(
+        patch_size=14, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362,
+        class_token=False, global_pool='avg', fc_norm=False,
+    )
+    model = _create_vision_transformer(
+        'vit_so400m_patch14_siglip_gap_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_so400m_patch14_siglip_gap_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    model_args = dict(
+        patch_size=14, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362,
+        class_token=False, global_pool='avg', fc_norm=False,
+    )
+    model = _create_vision_transformer(
+        'vit_so400m_patch14_siglip_gap_384', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_so400m_patch14_siglip_gap_448(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    model_args = dict(
+        patch_size=14, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362,
+        class_token=False, global_pool='avg', fc_norm=False,
+    )
+    model = _create_vision_transformer(
+        'vit_so400m_patch14_siglip_gap_448', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_so400m_patch14_siglip_gap_896(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    model_args = dict(
+        patch_size=14, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362,
+        class_token=False, global_pool='avg', fc_norm=False,
+    )
+    model = _create_vision_transformer(
+        'vit_so400m_patch14_siglip_gap_896', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
 
 
 @register_model
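
All four new entry points share the same SO400M architecture args; only the default pretrained cfg (weight file, input size) differs per tag. A hedged sketch for listing what this PR registers; the outputs in the comments are my expectation on this branch and other matching models may exist in later timm versions:

import timm

print(timm.list_models('vit_so400m_patch14_siglip_gap*'))
# expected on this branch:
# ['vit_so400m_patch14_siglip_gap_224', 'vit_so400m_patch14_siglip_gap_384',
#  'vit_so400m_patch14_siglip_gap_448', 'vit_so400m_patch14_siglip_gap_896']

print(timm.list_pretrained('vit_so400m_patch14_siglip_gap*'))
# expected to include the .webli, .pali_mix and .pali_pt tags added above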
