
Commit 6f35ed1

Add Deeplabv3Plus and DeepLabV3 with segmentation (#1869)
* Add Deeplab and DeepLabV3 with segmentation
* address comments
* test fix
* update copyright
* add preprocessor
* fix task test for deeplab
* nit
* modify preprocessor for masks
* fix image shape to provide tuple
* nit
1 parent 11227f3 commit 6f35ed1

12 files changed: +743, -4 lines changed

keras_hub/api/layers/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -34,6 +34,9 @@
 from keras_hub.src.layers.preprocessing.random_deletion import RandomDeletion
 from keras_hub.src.layers.preprocessing.random_swap import RandomSwap
 from keras_hub.src.layers.preprocessing.start_end_packer import StartEndPacker
+from keras_hub.src.models.deeplab_v3.deeplab_v3_image_converter import (
+    DeepLabV3ImageConverter,
+)
 from keras_hub.src.models.densenet.densenet_image_converter import (
     DenseNetImageConverter,
 )

keras_hub/api/models/__init__.py

Lines changed: 9 additions & 0 deletions
@@ -85,6 +85,15 @@
 from keras_hub.src.models.deberta_v3.deberta_v3_tokenizer import (
     DebertaV3Tokenizer,
 )
+from keras_hub.src.models.deeplab_v3.deeplab_v3_backbone import (
+    DeepLabV3Backbone,
+)
+from keras_hub.src.models.deeplab_v3.deeplab_v3_image_segmeter_preprocessor import (
+    DeepLabV3ImageSegmenterPreprocessor,
+)
+from keras_hub.src.models.deeplab_v3.deeplab_v3_segmenter import (
+    DeepLabV3ImageSegmenter,
+)
 from keras_hub.src.models.densenet.densenet_backbone import DenseNetBackbone
 from keras_hub.src.models.densenet.densenet_image_classifier import (
     DenseNetImageClassifier,

keras_hub/src/models/deeplab_v3/__init__.py

Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
+from keras_hub.src.models.deeplab_v3.deeplab_v3_backbone import (
+    DeepLabV3Backbone,
+)
+from keras_hub.src.models.deeplab_v3.deeplab_v3_presets import backbone_presets
+from keras_hub.src.utils.preset_utils import register_presets
+
+register_presets(backbone_presets, DeepLabV3Backbone)
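Registering the presets hooks `DeepLabV3Backbone` into the shared `from_preset` machinery. A minimal sketch of what this enables; the preset name below is hypothetical, since the real identifiers live in `deeplab_v3_presets.py`, which is not part of this section:

```python
import keras_hub

# Hypothetical preset name; actual names are defined in
# keras_hub/src/models/deeplab_v3/deeplab_v3_presets.py.
backbone = keras_hub.models.DeepLabV3Backbone.from_preset(
    "deeplab_v3_plus_resnet50_pascalvoc"
)
```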
keras_hub/src/models/deeplab_v3/deeplab_v3_backbone.py

Lines changed: 196 additions & 0 deletions
@@ -0,0 +1,196 @@
+import keras
+
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.backbone import Backbone
+from keras_hub.src.models.deeplab_v3.deeplab_v3_layers import (
+    SpatialPyramidPooling,
+)
+
+
+@keras_hub_export("keras_hub.models.DeepLabV3Backbone")
+class DeepLabV3Backbone(Backbone):
+    """DeepLabV3 & DeepLabV3Plus architecture for semantic segmentation.
+
+    This class implements a DeepLabV3 & DeepLabV3Plus architecture as described
+    in [Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation](
+    https://arxiv.org/abs/1802.02611) (ECCV 2018)
+    and [Rethinking Atrous Convolution for Semantic Image Segmentation](
+    https://arxiv.org/abs/1706.05587) (CVPR 2017).
+
+    Args:
+        image_encoder: `keras.Model`. An instance that is used as a feature
+            extractor for the encoder. Should either be a
+            `keras_hub.models.Backbone` or a `keras.Model` that implements the
+            `pyramid_outputs` property with keys "P2", "P3", etc. as values.
+            A sensible backbone to use in many cases is
+            `keras_hub.models.ResNetBackbone.from_preset("resnet_v2_50")`.
+        projection_filters: int. Number of filters in the convolution layer
+            projecting low-level features from the `image_encoder`.
+        spatial_pyramid_pooling_key: str. A layer level to extract and perform
+            `spatial_pyramid_pooling` on, one of the keys from the
+            `image_encoder` `pyramid_outputs` property such as "P4", "P5", etc.
+        upsampling_size: int or tuple of 2 integers. The upsampling factors for
+            rows and columns of the `spatial_pyramid_pooling` layer.
+            If `low_level_feature_key` is given, the `spatial_pyramid_pooling`
+            layer resolution should match the `low_level_feature` layer
+            resolution so both layers can be concatenated for the combined
+            encoder outputs.
+        dilation_rates: list. A `list` of integers for the parallel dilated
+            convolutions applied in `SpatialPyramidPooling`. A common choice
+            of rates is `[6, 12, 18]`.
+        low_level_feature_key: str, optional. A layer level to extract features
+            from, one of the keys of the `image_encoder`'s `pyramid_outputs`
+            property such as "P2", "P3", etc., which feeds the decoder block.
+            Required only when the DeepLabV3Plus architecture needs to be applied.
+        image_shape: tuple. The input shape without the batch size.
+            Defaults to `(None, None, 3)`.
+
+    Example:
+    ```python
+    # Load a trained backbone to extract features from its `pyramid_outputs`.
+    image_encoder = keras_hub.models.ResNetBackbone.from_preset("resnet_50_imagenet")
+
+    model = keras_hub.models.DeepLabV3Backbone(
+        image_encoder=image_encoder,
+        projection_filters=48,
+        low_level_feature_key="P2",
+        spatial_pyramid_pooling_key="P5",
+        upsampling_size=8,
+        dilation_rates=[6, 12, 18],
+    )
+    ```
+    """
+
+    def __init__(
+        self,
+        image_encoder,
+        spatial_pyramid_pooling_key,
+        upsampling_size,
+        dilation_rates,
+        low_level_feature_key=None,
+        projection_filters=48,
+        image_shape=(None, None, 3),
+        **kwargs,
+    ):
+        if not isinstance(image_encoder, keras.Model):
+            raise ValueError(
+                "Argument `image_encoder` must be a `keras.Model` instance. Received instead "
+                f"{image_encoder} (of type {type(image_encoder)})."
+            )
+        data_format = keras.config.image_data_format()
+        channel_axis = -1 if data_format == "channels_last" else 1
+
+        # === Layers ===
+        inputs = keras.layers.Input(image_shape, name="inputs")
+
+        fpn_model = keras.Model(
+            image_encoder.inputs, image_encoder.pyramid_outputs
+        )
+
+        fpn_outputs = fpn_model(inputs)
+
+        spatial_pyramid_pooling = SpatialPyramidPooling(
+            dilation_rates=dilation_rates
+        )
+        spatial_backbone_features = fpn_outputs[spatial_pyramid_pooling_key]
+        spp_outputs = spatial_pyramid_pooling(spatial_backbone_features)
+
+        encoder_outputs = keras.layers.UpSampling2D(
+            size=upsampling_size,
+            interpolation="bilinear",
+            name="encoder_output_upsampling",
+            data_format=data_format,
+        )(spp_outputs)
+
+        if low_level_feature_key:
+            decoder_feature = fpn_outputs[low_level_feature_key]
+            low_level_projected_features = apply_low_level_feature_network(
+                decoder_feature, projection_filters, channel_axis
+            )
+
+            encoder_outputs = keras.layers.Concatenate(
+                axis=channel_axis, name="encoder_decoder_concat"
+            )([encoder_outputs, low_level_projected_features])
+        # upsampling to the original image size
+        upsampling = (2 ** int(spatial_pyramid_pooling_key[-1])) // (
+            int(upsampling_size[0])
+            if isinstance(upsampling_size, tuple)
+            else upsampling_size
+        )
+        # === Functional Model ===
+        x = keras.layers.Conv2D(
+            name="segmentation_head_conv",
+            filters=256,
+            kernel_size=1,
+            padding="same",
+            use_bias=False,
+            data_format=data_format,
+        )(encoder_outputs)
+        x = keras.layers.BatchNormalization(
+            name="segmentation_head_norm", axis=channel_axis
+        )(x)
+        x = keras.layers.ReLU(name="segmentation_head_relu")(x)
+        x = keras.layers.UpSampling2D(
+            size=upsampling,
+            interpolation="bilinear",
+            data_format=data_format,
+            name="backbone_output_upsampling",
+        )(x)
+
+        super().__init__(inputs=inputs, outputs=x, **kwargs)
+
+        # === Config ===
+        self.image_shape = image_shape
+        self.image_encoder = image_encoder
+        self.projection_filters = projection_filters
+        self.upsampling_size = upsampling_size
+        self.dilation_rates = dilation_rates
+        self.low_level_feature_key = low_level_feature_key
+        self.spatial_pyramid_pooling_key = spatial_pyramid_pooling_key
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "image_encoder": keras.saving.serialize_keras_object(
+                    self.image_encoder
+                ),
+                "projection_filters": self.projection_filters,
+                "dilation_rates": self.dilation_rates,
+                "upsampling_size": self.upsampling_size,
+                "low_level_feature_key": self.low_level_feature_key,
+                "spatial_pyramid_pooling_key": self.spatial_pyramid_pooling_key,
+                "image_shape": self.image_shape,
+            }
+        )
+        return config
+
+    @classmethod
+    def from_config(cls, config):
+        if "image_encoder" in config and isinstance(
+            config["image_encoder"], dict
+        ):
+            config["image_encoder"] = keras.layers.deserialize(
+                config["image_encoder"]
+            )
+        return super().from_config(config)
+
+
+def apply_low_level_feature_network(
+    input_tensor, projection_filters, channel_axis
+):
+    data_format = keras.config.image_data_format()
+    x = keras.layers.Conv2D(
+        name="decoder_conv",
+        filters=projection_filters,
+        kernel_size=1,
+        padding="same",
+        use_bias=False,
+        data_format=data_format,
+    )(input_tensor)
+
+    x = keras.layers.BatchNormalization(name="decoder_norm", axis=channel_axis)(
+        x
+    )
+    x = keras.layers.ReLU(name="decoder_relu")(x)
+    return x
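For context, a hedged end-to-end sketch of how the backbone composes with the task model added in this commit. The `DeepLabV3ImageSegmenter` arguments (`backbone`, `num_classes`) are assumptions inferred from the import added to `keras_hub/api/models/__init__.py`; the segmenter file itself does not appear in this section. The comments trace the upsampling arithmetic: "P5" features sit at stride 32, `upsampling_size=8` brings them to stride 4 to match "P2", and the head then upsamples by `2**5 // 8 = 4` back to full resolution.

```python
import numpy as np
import keras_hub

# Pretrained feature extractor exposing `pyramid_outputs` ("P2" ... "P5").
image_encoder = keras_hub.models.ResNetBackbone.from_preset("resnet_50_imagenet")

backbone = keras_hub.models.DeepLabV3Backbone(
    image_encoder=image_encoder,
    projection_filters=48,
    low_level_feature_key="P2",        # decoder branch -> DeepLabV3Plus variant
    spatial_pyramid_pooling_key="P5",  # stride-32 features feed the pooling module
    upsampling_size=8,                 # stride 32 / 8 = stride 4, matching "P2"
    dilation_rates=[6, 12, 18],
    image_shape=(512, 512, 3),
)

# Assumed task-model signature; the segmenter file is not shown in this section.
segmenter = keras_hub.models.DeepLabV3ImageSegmenter(
    backbone=backbone,
    num_classes=21,
)
masks = segmenter.predict(np.ones((1, 512, 512, 3), dtype="float32"))
print(masks.shape)  # (1, 512, 512, 21) under these assumptions
```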
Lines changed: 73 additions & 0 deletions
@@ -0,0 +1,73 @@
+import keras
+import numpy as np
+import pytest
+
+from keras_hub.src.models.deeplab_v3.deeplab_v3_backbone import (
+    DeepLabV3Backbone,
+)
+from keras_hub.src.models.deeplab_v3.deeplab_v3_layers import (
+    SpatialPyramidPooling,
+)
+from keras_hub.src.models.resnet.resnet_backbone import ResNetBackbone
+from keras_hub.src.tests.test_case import TestCase
+
+
+class DeepLabV3Test(TestCase):
+    def setUp(self):
+        self.resnet_kwargs = {
+            "input_conv_filters": [64],
+            "input_conv_kernel_sizes": [7],
+            "stackwise_num_filters": [64, 64, 64],
+            "stackwise_num_blocks": [2, 2, 2],
+            "stackwise_num_strides": [1, 2, 2],
+            "block_type": "basic_block",
+            "use_pre_activation": False,
+        }
+        self.image_encoder = ResNetBackbone(**self.resnet_kwargs)
+        self.init_kwargs = {
+            "image_encoder": self.image_encoder,
+            "low_level_feature_key": "P2",
+            "spatial_pyramid_pooling_key": "P4",
+            "dilation_rates": [6, 12, 18],
+            "upsampling_size": 4,
+            "image_shape": (96, 96, 3),
+        }
+        self.input_data = np.ones((2, 96, 96, 3), dtype="float32")
+
+    def test_segmentation_basics(self):
+        self.run_vision_backbone_test(
+            cls=DeepLabV3Backbone,
+            init_kwargs=self.init_kwargs,
+            input_data=self.input_data,
+            expected_output_shape=(2, 96, 96, 256),
+            run_mixed_precision_check=False,
+            run_quantization_check=False,
+            run_data_format_check=False,
+        )
+
+    @pytest.mark.large
+    def test_saved_model(self):
+        self.run_model_saving_test(
+            cls=DeepLabV3Backbone,
+            init_kwargs=self.init_kwargs,
+            input_data=self.input_data,
+        )
+
+
+class SpatialPyramidPoolingTest(TestCase):
+    def test_layer_behaviors(self):
+        self.run_layer_test(
+            cls=SpatialPyramidPooling,
+            init_kwargs={
+                "dilation_rates": [6, 12, 18],
+                "activation": "relu",
+                "num_channels": 256,
+                "dropout": 0.1,
+            },
+            input_data=keras.random.uniform(shape=(1, 4, 4, 6)),
+            expected_output_shape=(1, 4, 4, 256),
+            expected_num_trainable_weights=18,
+            expected_num_non_trainable_variables=13,
+            expected_num_non_trainable_weights=12,
+            run_precision_checks=False,
+        )

keras_hub/src/models/deeplab_v3/deeplab_v3_image_converter.py

Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.layers.preprocessing.image_converter import ImageConverter
+from keras_hub.src.models.deeplab_v3.deeplab_v3_backbone import (
+    DeepLabV3Backbone,
+)
+
+
+@keras_hub_export("keras_hub.layers.DeepLabV3ImageConverter")
+class DeepLabV3ImageConverter(ImageConverter):
+    backbone_cls = DeepLabV3Backbone
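The converter subclass is intentionally empty: it only binds the generic `ImageConverter` preprocessing layer to the DeepLabV3 backbone so presets can resolve the right converter. A small usage sketch; `image_size` and `scale` come from the base `ImageConverter`, and the concrete values here are illustrative:

```python
import numpy as np

from keras_hub.src.models.deeplab_v3.deeplab_v3_image_converter import (
    DeepLabV3ImageConverter,
)

# Resize raw images to the backbone resolution and rescale pixels to [0, 1].
converter = DeepLabV3ImageConverter(image_size=(96, 96), scale=1.0 / 255)
images = np.random.randint(0, 256, size=(2, 512, 512, 3)).astype("float32")
resized = converter(images)  # expected shape: (2, 96, 96, 3)
```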
keras_hub/src/models/deeplab_v3/deeplab_v3_image_segmeter_preprocessor.py

Lines changed: 16 additions & 0 deletions
@@ -0,0 +1,16 @@
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.models.deeplab_v3.deeplab_v3_backbone import (
+    DeepLabV3Backbone,
+)
+from keras_hub.src.models.deeplab_v3.deeplab_v3_image_converter import (
+    DeepLabV3ImageConverter,
+)
+from keras_hub.src.models.image_segmenter_preprocessor import (
+    ImageSegmenterPreprocessor,
+)
+
+
+@keras_hub_export("keras_hub.models.DeepLabV3ImageSegmenterPreprocessor")
+class DeepLabV3ImageSegmenterPreprocessor(ImageSegmenterPreprocessor):
+    backbone_cls = DeepLabV3Backbone
+    image_converter_cls = DeepLabV3ImageConverter
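Like the converter, the preprocessor is mostly declarative: it inherits the image and mask handling from `ImageSegmenterPreprocessor` (the commit message notes "modify preprocessor for masks") and simply points it at the DeepLabV3 classes. A hedged sketch of the intended use; the `(x, y)` call convention is assumed from the `Preprocessor` base class:

```python
import numpy as np

from keras_hub.src.models.deeplab_v3.deeplab_v3_image_converter import (
    DeepLabV3ImageConverter,
)
from keras_hub.src.models.deeplab_v3.deeplab_v3_image_segmeter_preprocessor import (
    DeepLabV3ImageSegmenterPreprocessor,
)

preprocessor = DeepLabV3ImageSegmenterPreprocessor(
    image_converter=DeepLabV3ImageConverter(image_size=(96, 96), scale=1.0 / 255)
)
images = np.random.randint(0, 256, size=(2, 512, 512, 3)).astype("float32")
masks = np.zeros((2, 512, 512, 1), dtype="float32")
# Images are resized and rescaled; masks are expected to be resized to match.
x, y = preprocessor(images, masks)
```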
