Commit 55da400

[Semantic Segmentation] - Add SegFormer Architecture, Weight Conversion Script and Presets (#1883)
* initial commit - tf-based, kcv
* porting to keras_hub structure - removing aliases, presets, etc.
* enable instantiation of segformer backbone with custom MiT backbone
* remove num_classes from backbone
* fix input
* add imports to __init__
* update preset
* update docstrings
* add basic tests
* remove redundant imports
* update docstrings
* remove unused import
* running api_gen.py
* undo refactor of mit
* update docstrings
* add presets for mit
* add standin paths
* add presets for segformer backbone
* register presets in __init__.py
* addressing comments
* addressing comments
* addressing comments
* update most tests
* add remaining tests
* remove copyright
* fix test
* override from_config
* fix op in overlapping patching and embedding, start adding conversion utils
* style
* add padding to MiT patching and embedding
* update to support other presets
* update conversion script
* fix link for b5
* add cityscapes weights
* update presets
* update presets
* update conversion script to make directories
* use save_preset
* change name of output dir
* add preprocessor flow
* api gen and add preprocessor to mits
* conform to new image classifier style
* format
* resizing image converter -> ImageConverter
* merge mit branch into segformer branch
* add preprocessor and converter
* address comments
* clarify backbone usage
* add conversion script
* numerical equivalence changes
* fix numerical inaccuracies
* update conversion script
* update conversion script
* remove transpose
* add preprocessor to segformer class
* fix preset path
* update test shape
* update presets
* update test shape
* expand docstrings
* add rescaling and normalization to preprocessor
* remove backbone presets, remove copyrights, remove backbone cls from segmenter
* remove copyright and unused import
* apply same transformation to masks as input images
* fix import
* fix shape in tests
1 parent 23a2f22 commit 55da400

17 files changed: +844 −20 lines

keras_hub/api/layers/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -50,6 +50,9 @@
 from keras_hub.src.models.sam.sam_image_converter import SAMImageConverter
 from keras_hub.src.models.sam.sam_mask_decoder import SAMMaskDecoder
 from keras_hub.src.models.sam.sam_prompt_encoder import SAMPromptEncoder
+from keras_hub.src.models.segformer.segformer_image_converter import (
+    SegFormerImageConverter,
+)
 from keras_hub.src.models.vgg.vgg_image_converter import VGGImageConverter
 from keras_hub.src.models.whisper.whisper_audio_converter import (
     WhisperAudioConverter,

keras_hub/api/models/__init__.py

Lines changed: 7 additions & 0 deletions
@@ -266,6 +266,13 @@
 from keras_hub.src.models.sam.sam_image_segmenter_preprocessor import (
     SAMImageSegmenterPreprocessor,
 )
+from keras_hub.src.models.segformer.segformer_backbone import SegFormerBackbone
+from keras_hub.src.models.segformer.segformer_image_segmenter import (
+    SegFormerImageSegmenter,
+)
+from keras_hub.src.models.segformer.segformer_image_segmenter_preprocessor import (
+    SegFormerImageSegmenterPreprocessor,
+)
 from keras_hub.src.models.seq_2_seq_lm import Seq2SeqLM
 from keras_hub.src.models.seq_2_seq_lm_preprocessor import Seq2SeqLMPreprocessor
 from keras_hub.src.models.stable_diffusion_3.stable_diffusion_3_backbone import (

keras_hub/src/models/mit/mit_backbone.py

Lines changed: 12 additions & 1 deletion
@@ -1,3 +1,14 @@
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 import keras
 import numpy as np
 from keras import ops
@@ -100,7 +111,7 @@ def __init__(
             ]
             transformer_blocks.append(transformer_block)
             cur += depths[i]
-            layer_norms.append(keras.layers.LayerNormalization())
+            layer_norms.append(keras.layers.LayerNormalization(epsilon=1e-5))

         # === Functional Model ===
         image_input = keras.layers.Input(shape=image_shape)
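The explicit `epsilon=1e-5` is one of the "numerical equivalence changes" from the commit log: Keras `LayerNormalization` defaults to `epsilon=1e-3`, while `torch.nn.LayerNorm`, as used in PyTorch SegFormer implementations, defaults to `eps=1e-5`. A minimal sketch of the mismatch, using only the standard Keras API (nothing from this commit):

```python
import numpy as np
import keras
from keras import ops

# Small activations make the epsilon term significant in the
# 1 / sqrt(var + eps) normalization factor.
x = np.random.rand(2, 8, 4).astype("float32") * 1e-2

ln_default = keras.layers.LayerNormalization()              # epsilon=1e-3
ln_matched = keras.layers.LayerNormalization(epsilon=1e-5)  # torch default

# A nonzero gap here is enough to fail layer-by-layer numerical checks
# in a weight-conversion script.
print(ops.max(ops.abs(ln_default(x) - ln_matched(x))))
```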

keras_hub/src/models/mit/mit_backbone_test.py

Lines changed: 4 additions & 4 deletions
@@ -9,7 +9,7 @@ class MiTBackboneTest(TestCase):
     def setUp(self):
         self.init_kwargs = {
             "depths": [2, 2],
-            "image_shape": (16, 16, 3),
+            "image_shape": (32, 32, 3),
             "hidden_dims": [4, 8],
             "num_layers": 2,
             "blockwise_num_heads": [1, 2],
@@ -18,7 +18,7 @@ def setUp(self):
             "patch_sizes": [7, 3],
             "strides": [4, 2],
         }
-        self.input_size = 16
+        self.input_size = 32
         self.input_data = np.ones(
             (2, self.input_size, self.input_size, 3), dtype="float32"
         )
@@ -28,9 +28,9 @@ def test_backbone_basics(self):
             cls=MiTBackbone,
             init_kwargs=self.init_kwargs,
             input_data=self.input_data,
-            expected_output_shape=(2, 2, 2, 8),
+            expected_output_shape=(2, 4, 4, 8),
             expected_pyramid_output_keys=["P1", "P2"],
-            expected_pyramid_image_sizes=[(4, 4), (2, 2)],
+            expected_pyramid_image_sizes=[(8, 8), (4, 4)],
             run_quantization_check=False,
             run_mixed_precision_check=False,
             run_data_format_check=False,
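The updated expectations follow from the stage strides: each MiT stage downsamples the spatial size by its stride, so a 32×32 input with `strides=[4, 2]` yields pyramid features of 8×8 (`P1`) and 4×4 (`P2`), and the final output shape `(2, 4, 4, 8)` for a batch of 2 with `hidden_dims[-1] = 8`. The arithmetic as a quick sketch:

```python
# Each MiT stage divides the spatial size by its stride.
input_size = 32
strides = [4, 2]

size = input_size
pyramid_sizes = []
for stride in strides:
    size //= stride
    pyramid_sizes.append((size, size))

print(pyramid_sizes)  # [(8, 8), (4, 4)] -> P1 and P2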

keras_hub/src/models/mit/mit_image_classifier_test.py

Lines changed: 3 additions & 3 deletions
@@ -9,11 +9,11 @@
 class MiTImageClassifierTest(TestCase):
     def setUp(self):
         # Setup model.
-        self.images = np.ones((2, 16, 16, 3), dtype="float32")
+        self.images = np.ones((2, 32, 32, 3), dtype="float32")
         self.labels = [0, 3]
         self.backbone = MiTBackbone(
             depths=[2, 2, 2, 2],
-            image_shape=(16, 16, 3),
+            image_shape=(32, 32, 3),
             hidden_dims=[4, 8],
             num_layers=2,
             blockwise_num_heads=[1, 2],
@@ -40,7 +40,7 @@ def test_classifier_basics(self):
             cls=MiTImageClassifier,
             init_kwargs=self.init_kwargs,
             train_data=self.train_data,
-            expected_output_shape=(2, 2),
+            expected_output_shape=(4, 4),
         )

     @pytest.mark.large

keras_hub/src/models/mit/mit_layers.py

Lines changed: 9 additions & 7 deletions
@@ -183,20 +183,21 @@ def __init__(self, project_dim, num_heads, sr_ratio):
         self.k = keras.layers.Dense(project_dim)
         self.v = keras.layers.Dense(project_dim)
         self.proj = keras.layers.Dense(project_dim)
+        self.dropout = keras.layers.Dropout(0.1)
+        self.proj_drop = keras.layers.Dropout(0.1)

         if sr_ratio > 1:
             self.sr = keras.layers.Conv2D(
                 filters=project_dim,
                 kernel_size=sr_ratio,
                 strides=sr_ratio,
-                padding="same",
             )
-            self.norm = keras.layers.LayerNormalization()
+            self.norm = keras.layers.LayerNormalization(epsilon=1e-5)

     def call(self, x):
         input_shape = ops.shape(x)
         H, W = int(math.sqrt(input_shape[1])), int(math.sqrt(input_shape[1]))
-        B, C = input_shape[0], input_shape[2]
+        B, N, C = input_shape[0], input_shape[1], input_shape[2]

         q = self.q(x)
         q = ops.reshape(
@@ -212,12 +213,11 @@ def call(self, x):

         if self.sr_ratio > 1:
             x = ops.reshape(
-                ops.transpose(x, [0, 2, 1]),
+                x,
                 (B, H, W, C),
             )
             x = self.sr(x)
-            x = ops.reshape(x, [input_shape[0], input_shape[2], -1])
-            x = ops.transpose(x, [0, 2, 1])
+            x = ops.reshape(x, [B, -1, C])
             x = self.norm(x)

         k = self.k(x)
@@ -241,14 +241,16 @@ def call(self, x):

         attn = (q @ ops.transpose(k, [0, 1, 3, 2])) * self.scale
         attn = ops.nn.softmax(attn, axis=-1)
+        attn = self.dropout(attn)

         attn = attn @ v
         attn = ops.reshape(
             ops.transpose(attn, [0, 2, 1, 3]),
-            [input_shape[0], input_shape[1], input_shape[2]],
+            [B, N, C],
         )

         x = self.proj(attn)
+        x = self.proj_drop(x)
         return x
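Besides the new attention and projection `Dropout(0.1)` layers, the removed transposes are the substance of this hunk (the commit log's "remove transpose"): the old code followed the channels-first PyTorch pattern of permuting `(B, N, C)` to `(B, C, N)` before reshaping, but the Keras tensors here are channels-last, so `(B, N, C)` reshapes directly to `(B, H, W, C)` and back with no permutation. A small sketch using only standard `keras.ops`, showing the direct reshape is a lossless round trip when `N = H * W`:

```python
import numpy as np
from keras import ops

B, H, W, C = 2, 4, 4, 8
x = np.random.rand(B, H * W, C).astype("float32")  # (B, N, C), channels-last

# Row-major layout means the N tokens already enumerate the spatial grid,
# so reshaping straight to (B, H, W, C) preserves every value in place.
spatial = ops.reshape(x, (B, H, W, C))
tokens = ops.reshape(spatial, (B, -1, C))

print(ops.max(ops.abs(tokens - x)))  # 0.0 -- lossless round trip
```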

keras_hub/src/models/mit/mit_presets.py

Lines changed: 1 addition & 1 deletion
@@ -76,7 +76,7 @@
             "official_name": "MiT",
             "path": "mit",
         },
-        "kaggle_handle": "kaggle://keras/mit/keras/mit_b5_ade20k_512/1",
+        "kaggle_handle": "kaggle://keras/mit/keras/mit_b5_ade20k_640/1",
     },
     "mit_b0_cityscapes_1024": {
         "metadata": {

keras_hub/src/models/mit/mix_transformer_backbone_test.py

Lines changed: 4 additions & 4 deletions
@@ -9,7 +9,7 @@ class MiTBackboneTest(TestCase):
     def setUp(self):
         self.init_kwargs = {
             "depths": [2, 2],
-            "image_shape": (16, 16, 3),
+            "image_shape": (32, 32, 3),
             "hidden_dims": [4, 8],
             "num_layers": 2,
             "blockwise_num_heads": [1, 2],
@@ -18,7 +18,7 @@ def setUp(self):
             "patch_sizes": [7, 3],
             "strides": [4, 2],
         }
-        self.input_size = 16
+        self.input_size = 32
         self.input_data = np.ones(
             (2, self.input_size, self.input_size, 3), dtype="float32"
         )
@@ -28,9 +28,9 @@ def test_backbone_basics(self):
             cls=MiTBackbone,
             init_kwargs=self.init_kwargs,
             input_data=self.input_data,
-            expected_output_shape=(2, 2, 2, 8),
+            expected_output_shape=(2, 4, 4, 8),
             expected_pyramid_output_keys=["P1", "P2"],
-            expected_pyramid_image_sizes=[(4, 4), (2, 2)],
+            expected_pyramid_image_sizes=[(8, 8), (4, 4)],
             run_quantization_check=False,
             run_mixed_precision_check=False,
             run_data_format_check=False,

keras_hub/src/models/segformer/__init__.py

Lines changed: 8 additions & 0 deletions

@@ -0,0 +1,8 @@
+from keras_hub.src.models.segformer.segformer_backbone import SegFormerBackbone
+from keras_hub.src.models.segformer.segformer_image_segmenter import (
+    SegFormerImageSegmenter,
+)
+from keras_hub.src.models.segformer.segformer_presets import presets
+from keras_hub.src.utils.preset_utils import register_presets
+
+register_presets(presets, SegFormerImageSegmenter)
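Once `register_presets` runs at import time, the SegFormer presets resolve through the standard `from_preset` flow on `SegFormerImageSegmenter`. A hypothetical usage sketch (the preset name below is an illustration only; the real names live in `segformer_presets.py`, which this diff view does not show):

```python
import numpy as np
import keras_hub

# NOTE: preset name is hypothetical -- see segformer_presets.py for the
# actual registered names (the commit adds ADE20K and Cityscapes weights).
segmenter = keras_hub.models.SegFormerImageSegmenter.from_preset(
    "segformer_b0_ade20k_512"
)

images = np.random.rand(1, 512, 512, 3).astype("float32")
masks = segmenter.predict(images)  # per-pixel class scores
```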

keras_hub/src/models/segformer/segformer_backbone.py

Lines changed: 163 additions & 0 deletions

@@ -0,0 +1,163 @@
+import keras
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.backbone import Backbone
+
+
+@keras_hub_export("keras_hub.models.SegFormerBackbone")
+class SegFormerBackbone(Backbone):
+    """A Keras model implementing the SegFormer architecture for semantic segmentation.
+
+    This class implements the majority of the SegFormer architecture described
+    in [SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers](https://arxiv.org/abs/2105.15203),
+    based on [the TensorFlow implementation from DeepVision](https://github.com/DavidLandup0/deepvision/tree/main/deepvision/models/segmentation/segformer).
+
+    SegFormers are meant to be used with the MixTransformer (MiT) encoder
+    family, and use a very lightweight all-MLP decoder head.
+
+    The MiT encoder uses a hierarchical transformer which outputs features at
+    multiple scales, similar to the hierarchical outputs typically associated
+    with CNNs.
+
+    Args:
+        image_encoder: `keras.Model`. The backbone network for the model that is
+            used as a feature extractor for the SegFormer encoder.
+            Should be used with the MiT backbone model
+            (`keras_hub.models.MiTBackbone`) which was created
+            specifically for SegFormers.
+        projection_filters: int, number of filters in the
+            convolution layer projecting the concatenated features into
+            a segmentation map. Defaults to `256`.
+
+    Example:
+
+    Using the class with a custom `backbone`:
+
+    ```python
+    import keras_hub
+
+    backbone = keras_hub.models.MiTBackbone(
+        depths=[2, 2, 2, 2],
+        image_shape=(224, 224, 3),
+        hidden_dims=[32, 64, 160, 256],
+        num_layers=4,
+        blockwise_num_heads=[1, 2, 5, 8],
+        blockwise_sr_ratios=[8, 4, 2, 1],
+        max_drop_path_rate=0.1,
+        patch_sizes=[7, 3, 3, 3],
+        strides=[4, 2, 2, 2],
+    )
+
+    segformer_backbone = keras_hub.models.SegFormerBackbone(
+        image_encoder=backbone, projection_filters=256
+    )
+    ```
+
+    Using the class with a preset `backbone`:
+
+    ```python
+    import keras_hub
+
+    backbone = keras_hub.models.MiTBackbone.from_preset("mit_b0_ade20k_512")
+    segformer_backbone = keras_hub.models.SegFormerBackbone(
+        image_encoder=backbone, projection_filters=256
+    )
+    ```
+    """
+
+    def __init__(
+        self,
+        image_encoder,
+        projection_filters,
+        **kwargs,
+    ):
+        if not isinstance(image_encoder, keras.layers.Layer) and not isinstance(
+            image_encoder, keras.Model
+        ):
+            raise ValueError(
+                "Argument `image_encoder` must be a `keras.layers.Layer` "
+                "instance or a `keras.Model`. Received instead "
+                f"image_encoder={image_encoder} (of type {type(image_encoder)})."
+            )
+
+        # === Layers ===
+        inputs = keras.layers.Input(shape=image_encoder.input.shape[1:])
+
+        self.feature_extractor = keras.Model(
+            image_encoder.inputs, image_encoder.pyramid_outputs
+        )
+
+        features = self.feature_extractor(inputs)
+        # Get height and width of level one output.
+        _, height, width, _ = features["P1"].shape
+
+        self.mlp_blocks = []
+
+        for feature_dim, feature in zip(image_encoder.hidden_dims, features):
+            self.mlp_blocks.append(
+                keras.layers.Dense(
+                    projection_filters, name=f"linear_{feature_dim}"
+                )
+            )
+
+        self.resizing = keras.layers.Resizing(
+            height, width, interpolation="bilinear"
+        )
+        self.concat = keras.layers.Concatenate(axis=-1)
+        self.linear_fuse = keras.Sequential(
+            [
+                keras.layers.Conv2D(
+                    filters=projection_filters, kernel_size=1, use_bias=False
+                ),
+                keras.layers.BatchNormalization(epsilon=1e-5, momentum=0.9),
+                keras.layers.Activation("relu"),
+            ]
+        )
+
+        # === Functional Model ===
+        # Project all multi-level outputs onto the same dimensionality
+        # and feature map shape.
+        multi_layer_outs = []
+        for index, (feature_dim, feature) in enumerate(
+            zip(image_encoder.hidden_dims, features)
+        ):
+            out = self.mlp_blocks[index](features[feature])
+            out = self.resizing(out)
+            multi_layer_outs.append(out)
+
+        # Concat now-equal feature maps.
+        concatenated_outs = self.concat(multi_layer_outs[::-1])
+
+        # Fuse concatenated features into a segmentation map.
+        seg = self.linear_fuse(concatenated_outs)
+
+        super().__init__(
+            inputs=inputs,
+            outputs=seg,
+            **kwargs,
+        )
+
+        # === Config ===
+        self.projection_filters = projection_filters
+        self.image_encoder = image_encoder
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "projection_filters": self.projection_filters,
+                "image_encoder": keras.saving.serialize_keras_object(
+                    self.image_encoder
+                ),
+            }
+        )
+        return config
+
+    @classmethod
+    def from_config(cls, config):
+        if "image_encoder" in config and isinstance(
+            config["image_encoder"], dict
+        ):
+            config["image_encoder"] = keras.layers.deserialize(
+                config["image_encoder"]
+            )
+        return super().from_config(config)
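The `get_config`/`from_config` pair serializes the nested `image_encoder` as a dict, so the override (the commit log's "override from_config") has to deserialize it back into a model before the constructor runs. A minimal round-trip sketch, reusing the small two-stage MiT configuration from the tests in this commit (`blockwise_sr_ratios=[8, 4]` is assumed for the two-stage case):

```python
import keras_hub

encoder = keras_hub.models.MiTBackbone(
    depths=[2, 2],
    image_shape=(32, 32, 3),
    hidden_dims=[4, 8],
    num_layers=2,
    blockwise_num_heads=[1, 2],
    blockwise_sr_ratios=[8, 4],  # assumed two-stage values
    max_drop_path_rate=0.1,
    patch_sizes=[7, 3],
    strides=[4, 2],
)
backbone = keras_hub.models.SegFormerBackbone(
    image_encoder=encoder, projection_filters=256
)

# get_config() serializes the nested encoder to a dict; from_config()
# hits the isinstance(..., dict) branch above and deserializes it
# before calling the constructor.
config = backbone.get_config()
restored = keras_hub.models.SegFormerBackbone.from_config(config)
```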
