
Commit 627fbdb

🚨🚨🚨 Fix UperNet model and add pretrained checkpoints (#1124)
* not relevant: fix predict method
* not relevant: one-line refactor timm-universal
* BREAKING: change UPerNet model interface (decoder_channels)
* BREAKING: huge refactoring of upernet decoder for weight compat
* Add conversion script for UperNet
* Add example
* Update conversion
* Remove prints from decoder
* Fixup
* Fix torch_scriptable
* Update init
* Update export test
* Update readme
* bump pillow
1 parent d9ab82a commit 627fbdb
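In user code, the breaking change shows up in how a UPerNet is constructed. A minimal sketch (the encoder name, `decoder_channels=512`, and `classes=150` mirror the values used in the conversion script below; treat it as illustrative rather than the only valid configuration):

```python
import segmentation_models_pytorch as smp

# decoder_channels now configures the UPerNet decoder width as a single value
# (512 here, matching the converted openmmlab checkpoints added by this PR).
model = smp.UPerNet(
    encoder_name="tu-convnext_tiny",
    encoder_weights=None,
    decoder_channels=512,
    classes=150,
)
```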

File tree

11 files changed, +558 -83 lines changed


Diff for: README.md

+15-12
@@ -106,25 +106,28 @@ Congratulations! You are done! Now you can train your model with your favorite f
 | **Train** clothes binary segmentation by @ternaus | [Repo](https://github.com/ternaus/cloths_segmentation) | |
 | **Load and inference** pretrained Segformer | [Notebook](https://github.com/qubvel-org/segmentation_models.pytorch/blob/main/examples/segformer_inference_pretrained.ipynb) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/qubvel/segmentation_models.pytorch/blob/main/examples/segformer_inference_pretrained.ipynb) |
 | **Load and inference** pretrained DPT | [Notebook](https://github.com/qubvel-org/segmentation_models.pytorch/blob/main/examples/dpt_inference_pretrained.ipynb) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/qubvel/segmentation_models.pytorch/blob/main/examples/dpt_inference_pretrained.ipynb) |
+| **Load and inference** pretrained UPerNet | [Notebook](https://github.com/qubvel-org/segmentation_models.pytorch/blob/main/examples/upernet_inference_pretrained.ipynb) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/qubvel/segmentation_models.pytorch/blob/main/examples/upernet_inference_pretrained.ipynb) |
 | **Save and load** models locally / to HuggingFace Hub |[Notebook](https://github.com/qubvel-org/segmentation_models.pytorch/blob/main/examples/save_load_model_and_share_with_hf_hub.ipynb) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/qubvel/segmentation_models.pytorch/blob/main/examples/save_load_model_and_share_with_hf_hub.ipynb)
 | **Export** trained model to ONNX | [Notebook](https://github.com/qubvel/segmentation_models.pytorch/blob/main/examples/convert_to_onnx.ipynb) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/qubvel/segmentation_models.pytorch/blob/main/examples/convert_to_onnx.ipynb) |


 ## 📦 Models and encoders <a name="models-and-encoders"></a>

 ### Architectures <a name="architectures"></a>
-- Unet [[paper](https://arxiv.org/abs/1505.04597)] [[docs](https://smp.readthedocs.io/en/latest/models.html#unet)]
-- Unet++ [[paper](https://arxiv.org/pdf/1807.10165.pdf)] [[docs](https://smp.readthedocs.io/en/latest/models.html#id2)]
-- MAnet [[paper](https://ieeexplore.ieee.org/abstract/document/9201310)] [[docs](https://smp.readthedocs.io/en/latest/models.html#manet)]
-- Linknet [[paper](https://arxiv.org/abs/1707.03718)] [[docs](https://smp.readthedocs.io/en/latest/models.html#linknet)]
-- FPN [[paper](http://presentations.cocodataset.org/COCO17-Stuff-FAIR.pdf)] [[docs](https://smp.readthedocs.io/en/latest/models.html#fpn)]
-- PSPNet [[paper](https://arxiv.org/abs/1612.01105)] [[docs](https://smp.readthedocs.io/en/latest/models.html#pspnet)]
-- PAN [[paper](https://arxiv.org/abs/1805.10180)] [[docs](https://smp.readthedocs.io/en/latest/models.html#pan)]
-- DeepLabV3 [[paper](https://arxiv.org/abs/1706.05587)] [[docs](https://smp.readthedocs.io/en/latest/models.html#deeplabv3)]
-- DeepLabV3+ [[paper](https://arxiv.org/abs/1802.02611)] [[docs](https://smp.readthedocs.io/en/latest/models.html#id9)]
-- UPerNet [[paper](https://arxiv.org/abs/1807.10221)] [[docs](https://smp.readthedocs.io/en/latest/models.html#upernet)]
-- Segformer [[paper](https://arxiv.org/abs/2105.15203)] [[docs](https://smp.readthedocs.io/en/latest/models.html#segformer)]
-- DPT [[paper](https://arxiv.org/abs/2103.13413)] [[docs](https://smp.readthedocs.io/en/latest/models.html#dpt)]
+| Architecture | Paper | Documentation | Checkpoints |
+|--------------|-------|---------------|------------|
+| Unet | [paper](https://arxiv.org/abs/1505.04597) | [docs](https://smp.readthedocs.io/en/latest/models.html#unet) | |
+| Unet++ | [paper](https://arxiv.org/pdf/1807.10165.pdf) | [docs](https://smp.readthedocs.io/en/latest/models.html#unetplusplus) | |
+| MAnet | [paper](https://ieeexplore.ieee.org/abstract/document/9201310) | [docs](https://smp.readthedocs.io/en/latest/models.html#manet) | |
+| Linknet | [paper](https://arxiv.org/abs/1707.03718) | [docs](https://smp.readthedocs.io/en/latest/models.html#linknet) | |
+| FPN | [paper](http://presentations.cocodataset.org/COCO17-Stuff-FAIR.pdf) | [docs](https://smp.readthedocs.io/en/latest/models.html#fpn) | |
+| PSPNet | [paper](https://arxiv.org/abs/1612.01105) | [docs](https://smp.readthedocs.io/en/latest/models.html#pspnet) | |
+| PAN | [paper](https://arxiv.org/abs/1805.10180) | [docs](https://smp.readthedocs.io/en/latest/models.html#pan) | |
+| DeepLabV3 | [paper](https://arxiv.org/abs/1706.05587) | [docs](https://smp.readthedocs.io/en/latest/models.html#deeplabv3) | |
+| DeepLabV3+ | [paper](https://arxiv.org/abs/1802.02611) | [docs](https://smp.readthedocs.io/en/latest/models.html#deeplabv3plus) | |
+| UPerNet | [paper](https://arxiv.org/abs/1807.10221) | [docs](https://smp.readthedocs.io/en/latest/models.html#upernet) | [checkpoints](https://huggingface.co/collections/smp-hub/upernet-67fadcdbe08418c6ea94f768) |
+| Segformer | [paper](https://arxiv.org/abs/2105.15203) | [docs](https://smp.readthedocs.io/en/latest/models.html#segformer) | [checkpoints](https://huggingface.co/collections/smp-hub/segformer-6749eb4923dea2c355f29a1f) |
+| DPT | [paper](https://arxiv.org/abs/2103.13413) | [docs](https://smp.readthedocs.io/en/latest/models.html#dpt) | [checkpoints](https://huggingface.co/collections/smp-hub/dpt-67f30487327c0599a0c62d68) |

 ### Encoders <a name="encoders"></a>

Diff for: examples/upernet_inference_pretrained.ipynb

+153
Large diffs are not rendered by default.
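Since the notebook diff is collapsed here, the following is a hedged sketch of the inference flow it exercises: the repo id follows the naming used by the conversion script below, the preprocessing reproduces the albumentations transform saved with each checkpoint, loading via the hub mixin's `from_pretrained` is assumed, and `image.jpg` is a placeholder input path.

```python
import numpy as np
import torch
import albumentations as A
from PIL import Image
import segmentation_models_pytorch as smp

# Assumed checkpoint repo, as produced by scripts/models-conversions/upernet-hf-to-smp.py
model = smp.UPerNet.from_pretrained("smp-hub/upernet-convnext-tiny").eval()

# Same preprocessing the conversion script saves next to the weights
transform = A.Compose([
    A.Resize(512, 512),
    A.Normalize(mean=(123.675, 116.28, 103.53), std=(58.395, 57.12, 57.375), max_pixel_value=1.0),
])

image = np.array(Image.open("image.jpg").convert("RGB"))           # placeholder path
tensor = torch.from_numpy(transform(image=image)["image"]).permute(2, 0, 1).unsqueeze(0)

with torch.no_grad():
    logits = model(tensor)            # (1, 150, 512, 512): 150 ADE20K classes for these checkpoints
mask = logits.argmax(dim=1).squeeze(0)  # per-pixel class indices
```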

Diff for: requirements/required.txt

+1-1
@@ -1,6 +1,6 @@
 huggingface_hub==0.30.2
 numpy==2.2.4
-pillow==11.2.0
+pillow==11.2.1
 safetensors==0.5.3
 timm==1.0.15
 torch==2.6.0

Diff for: scripts/models-conversions/upernet-hf-to-smp.py

+249
@@ -0,0 +1,249 @@
import re
import torch
import albumentations as A
import segmentation_models_pytorch as smp
from huggingface_hub import hf_hub_download, HfApi
from collections import defaultdict

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# fmt: off
CONVNEXT_MAPPING = {
    r"backbone.embeddings.patch_embeddings.(weight|bias)": r"encoder.model.stem_0.\1",
    r"backbone.embeddings.layernorm.(weight|bias)": r"encoder.model.stem_1.\1",
    r"backbone.encoder.stages.(\d+).layers.(\d+).layer_scale_parameter": r"encoder.model.stages_\1.blocks.\2.gamma",
    r"backbone.encoder.stages.(\d+).layers.(\d+).dwconv.(weight|bias)": r"encoder.model.stages_\1.blocks.\2.conv_dw.\3",
    r"backbone.encoder.stages.(\d+).layers.(\d+).layernorm.(weight|bias)": r"encoder.model.stages_\1.blocks.\2.norm.\3",
    r"backbone.encoder.stages.(\d+).layers.(\d+).pwconv(\d+).(weight|bias)": r"encoder.model.stages_\1.blocks.\2.mlp.fc\3.\4",
    r"backbone.encoder.stages.(\d+).downsampling_layer.(\d+).(weight|bias)": r"encoder.model.stages_\1.downsample.\2.\3",
}

SWIN_MAPPING = {
    r"backbone.embeddings.patch_embeddings.projection": r"encoder.model.patch_embed.proj",
    r"backbone.embeddings.norm": r"encoder.model.patch_embed.norm",
    r"backbone.encoder.layers.(\d+).blocks.(\d+).layernorm_before": r"encoder.model.layers_\1.blocks.\2.norm1",
    r"backbone.encoder.layers.(\d+).blocks.(\d+).attention.self.relative_position_bias_table": r"encoder.model.layers_\1.blocks.\2.attn.relative_position_bias_table",
    r"backbone.encoder.layers.(\d+).blocks.(\d+).attention.self.(query|key|value)": r"encoder.model.layers_\1.blocks.\2.attn.\3",
    r"backbone.encoder.layers.(\d+).blocks.(\d+).attention.output.dense": r"encoder.model.layers_\1.blocks.\2.attn.proj",
    r"backbone.encoder.layers.(\d+).blocks.(\d+).layernorm_after": r"encoder.model.layers_\1.blocks.\2.norm2",
    r"backbone.encoder.layers.(\d+).blocks.(\d+).intermediate.dense": r"encoder.model.layers_\1.blocks.\2.mlp.fc1",
    r"backbone.encoder.layers.(\d+).blocks.(\d+).output.dense": r"encoder.model.layers_\1.blocks.\2.mlp.fc2",
    r"backbone.encoder.layers.(\d+).downsample.reduction": lambda x: f"encoder.model.layers_{1 + int(x.group(1))}.downsample.reduction",
    r"backbone.encoder.layers.(\d+).downsample.norm": lambda x: f"encoder.model.layers_{1 + int(x.group(1))}.downsample.norm",
}

DECODER_MAPPING = {
    # started from 1 in hf
    r"backbone.hidden_states_norms.stage(\d+)": lambda x: f"decoder.feature_norms.{int(x.group(1)) - 1}",

    r"decode_head.psp_modules.(\d+).(\d+).conv.weight": r"decoder.psp.blocks.\1.\2.0.weight",
    r"decode_head.psp_modules.(\d+).(\d+).batch_norm": r"decoder.psp.blocks.\1.\2.1",
    r"decode_head.bottleneck.conv.weight": r"decoder.psp.out_conv.0.weight",
    r"decode_head.bottleneck.batch_norm": r"decoder.psp.out_conv.1",

    # fpn blocks are in reverse order (3 blocks total, so 2 - i)
    r"decode_head.lateral_convs.(\d+).conv.weight": lambda x: f"decoder.fpn_lateral_blocks.{2 - int(x.group(1))}.conv_norm_relu.0.weight",
    r"decode_head.lateral_convs.(\d+).batch_norm": lambda x: f"decoder.fpn_lateral_blocks.{2 - int(x.group(1))}.conv_norm_relu.1",
    r"decode_head.fpn_convs.(\d+).conv.weight": lambda x: f"decoder.fpn_conv_blocks.{2 - int(x.group(1))}.0.weight",
    r"decode_head.fpn_convs.(\d+).batch_norm": lambda x: f"decoder.fpn_conv_blocks.{2 - int(x.group(1))}.1",

    r"decode_head.fpn_bottleneck.conv.weight": r"decoder.fusion_block.0.weight",
    r"decode_head.fpn_bottleneck.batch_norm": r"decoder.fusion_block.1",
    r"decode_head.classifier": r"segmentation_head.0",
}
# fmt: on

PRETRAINED_CHECKPOINTS = {
    "convnext-tiny": {
        "repo_id": "openmmlab/upernet-convnext-tiny",
        "encoder_name": "tu-convnext_tiny",
        "decoder_channels": 512,
        "classes": 150,
        "mapping": {**CONVNEXT_MAPPING, **DECODER_MAPPING},
    },
    "convnext-small": {
        "repo_id": "openmmlab/upernet-convnext-small",
        "encoder_name": "tu-convnext_small",
        "decoder_channels": 512,
        "classes": 150,
        "mapping": {**CONVNEXT_MAPPING, **DECODER_MAPPING},
    },
    "convnext-base": {
        "repo_id": "openmmlab/upernet-convnext-base",
        "encoder_name": "tu-convnext_base",
        "decoder_channels": 512,
        "classes": 150,
        "mapping": {**CONVNEXT_MAPPING, **DECODER_MAPPING},
    },
    "convnext-large": {
        "repo_id": "openmmlab/upernet-convnext-large",
        "encoder_name": "tu-convnext_large",
        "decoder_channels": 512,
        "classes": 150,
        "mapping": {**CONVNEXT_MAPPING, **DECODER_MAPPING},
    },
    "convnext-xlarge": {
        "repo_id": "openmmlab/upernet-convnext-xlarge",
        "encoder_name": "tu-convnext_xlarge",
        "decoder_channels": 512,
        "classes": 150,
        "mapping": {**CONVNEXT_MAPPING, **DECODER_MAPPING},
    },
    "swin-tiny": {
        "repo_id": "openmmlab/upernet-swin-tiny",
        "encoder_name": "tu-swin_tiny_patch4_window7_224",
        "decoder_channels": 512,
        "classes": 150,
        "extra_kwargs": {"img_size": 512},
        "mapping": {**SWIN_MAPPING, **DECODER_MAPPING},
    },
    "swin-small": {
        "repo_id": "openmmlab/upernet-swin-small",
        "encoder_name": "tu-swin_small_patch4_window7_224",
        "decoder_channels": 512,
        "classes": 150,
        "extra_kwargs": {"img_size": 512},
        "mapping": {**SWIN_MAPPING, **DECODER_MAPPING},
    },
    "swin-large": {
        "repo_id": "openmmlab/upernet-swin-large",
        "encoder_name": "tu-swin_large_patch4_window12_384",
        "decoder_channels": 512,
        "classes": 150,
        "extra_kwargs": {"img_size": 512},
        "mapping": {**SWIN_MAPPING, **DECODER_MAPPING},
    },
}


def convert_old_keys_to_new_keys(state_dict_keys: dict, keys_mapping: dict):
    """
    This function should be applied only once, on the concatenated keys to efficiently rename using
    the key mappings.
    """
    output_dict = {}
    if state_dict_keys is not None:
        old_text = "\n".join(state_dict_keys)
        new_text = old_text
        for pattern, replacement in keys_mapping.items():
            if replacement is None:
                new_text = re.sub(pattern, "", new_text)  # an empty line
                continue
            new_text = re.sub(pattern, replacement, new_text)
        output_dict = dict(zip(old_text.split("\n"), new_text.split("\n")))
    return output_dict


def group_qkv_layers(state_dict: dict) -> dict:
    """Find corresponding layer names for query, key and value layers and stack them in a single layer"""

    state_dict = state_dict.copy()  # shallow copy

    result = defaultdict(dict)
    layer_names = list(state_dict.keys())
    qkv_names = ["query", "key", "value"]
    for layer_name in layer_names:
        for pattern in qkv_names:
            if pattern in layer_name:
                new_key = layer_name.replace(pattern, "qkv")
                result[new_key][pattern] = state_dict.pop(layer_name)
                break

    # merge them all
    for new_key, patterns in result.items():
        state_dict[new_key] = torch.cat(
            [patterns[qkv_name] for qkv_name in qkv_names], dim=0
        )

    return state_dict


def convert_model(model_name: str, push_to_hub: bool = False):
    params = PRETRAINED_CHECKPOINTS[model_name]

    print(f"Converting model: {model_name}")
    print(f"Downloading weights from: {params['repo_id']}")

    hf_weights_path = hf_hub_download(
        repo_id=params["repo_id"], filename="pytorch_model.bin"
    )
    hf_state_dict = torch.load(hf_weights_path, weights_only=True)
    print(f"Loaded HuggingFace state dict with {len(hf_state_dict)} keys")

    # Rename keys
    keys_mapping = convert_old_keys_to_new_keys(hf_state_dict.keys(), params["mapping"])

    smp_state_dict = {}
    for old_key, new_key in keys_mapping.items():
        smp_state_dict[new_key] = hf_state_dict[old_key]

    # remove aux head
    smp_state_dict = {
        k: v for k, v in smp_state_dict.items() if "auxiliary_head." not in k
    }

    # [swin] group qkv layers and remove `relative_position_index`
    smp_state_dict = group_qkv_layers(smp_state_dict)
    smp_state_dict = {
        k: v for k, v in smp_state_dict.items() if "relative_position_index" not in k
    }

    # Create model
    print(f"Creating SMP UPerNet model with encoder: {params['encoder_name']}")
    extra_kwargs = params.get("extra_kwargs", {})
    smp_model = smp.UPerNet(
        encoder_name=params["encoder_name"],
        encoder_weights=None,
        decoder_channels=params["decoder_channels"],
        classes=params["classes"],
        **extra_kwargs,
    )

    print("Loading weights into SMP model...")
    smp_model.load_state_dict(smp_state_dict, strict=True)

    # Check we can run the model
    print("Verifying model with test inference...")
    smp_model.eval()
    sample = torch.ones(1, 3, 512, 512)
    with torch.no_grad():
        output = smp_model(sample)
    print(f"Test inference successful. Output shape: {output.shape}")

    # Save model with preprocessing
    smp_repo_id = f"smp-hub/upernet-{model_name}"
    print(f"Saving model to: {smp_repo_id}")
    smp_model.save_pretrained(save_directory=smp_repo_id)

    transform = A.Compose(
        [
            A.Resize(512, 512),
            A.Normalize(
                mean=(123.675, 116.28, 103.53),
                std=(58.395, 57.12, 57.375),
                max_pixel_value=1.0,
            ),
        ]
    )
    transform.save_pretrained(save_directory=smp_repo_id)

    if push_to_hub:
        print(f"Pushing model to HuggingFace Hub: {smp_repo_id}")
        api = HfApi()
        if not api.repo_exists(smp_repo_id):
            api.create_repo(repo_id=smp_repo_id, repo_type="model")
        api.upload_folder(
            repo_id=smp_repo_id,
            folder_path=smp_repo_id,
            repo_type="model",
        )

    print(f"Conversion of {model_name} completed successfully!")


if __name__ == "__main__":
    print(f"Starting conversion of {len(PRETRAINED_CHECKPOINTS)} UPerNet models")
    for model_name in PRETRAINED_CHECKPOINTS.keys():
        convert_model(model_name, push_to_hub=True)
    print("All conversions completed!")

Diff for: segmentation_models_pytorch/base/initialization.py

+3-1
@@ -8,7 +8,9 @@ def initialize_decoder(module):
             if m.bias is not None:
                 nn.init.constant_(m.bias, 0)
 
-        elif isinstance(m, nn.BatchNorm2d):
+        elif isinstance(
+            m, (nn.BatchNorm2d, nn.LayerNorm, nn.GroupNorm, nn.InstanceNorm2d)
+        ):
             nn.init.constant_(m.weight, 1)
             nn.init.constant_(m.bias, 0)
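The practical effect is that decoder normalization layers other than BatchNorm2d now receive the same constant initialization. A small toy check (the module below is illustrative, not from the library):

```python
import torch.nn as nn
from segmentation_models_pytorch.base.initialization import initialize_decoder

# Toy "decoder" mixing norm types, for illustration only.
toy = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.GroupNorm(num_groups=2, num_channels=8),
    nn.LayerNorm(8),
)

# Perturb the norm parameters, then let initialize_decoder reset them.
toy[1].weight.data.fill_(0.5)
toy[2].bias.data.fill_(0.3)
initialize_decoder(toy)

print(toy[1].weight.unique(), toy[2].bias.unique())  # tensor([1.]) tensor([0.])
```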

Diff for: segmentation_models_pytorch/base/model.py

+1-3
@@ -87,9 +87,7 @@ def predict(self, x):
         """
         if self.training:
             self.eval()
-
-        x = self.forward(x)
-
+        x = self(x)
         return x
 
     def load_state_dict(self, state_dict, **kwargs):
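A short usage sketch of the fixed `predict` (the model configuration below is assumed, taken from the conversion script above): dispatching through `self(x)` instead of `self.forward(x)` routes the call through `__call__`, so module hooks still fire during inference.

```python
import torch
import segmentation_models_pytorch as smp

# Assumed configuration for illustration; any smp model exposes .predict().
model = smp.UPerNet(
    encoder_name="tu-convnext_tiny",
    encoder_weights=None,
    decoder_channels=512,
    classes=150,
)

with torch.no_grad():
    logits = model.predict(torch.rand(1, 3, 512, 512))
print(logits.shape)  # expected: torch.Size([1, 150, 512, 512])
```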
