-
Notifications
You must be signed in to change notification settings - Fork 283
Add Deeplabv3Plus and DeepLabV3 with segmentation #1869
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
divyashreepathihalli
merged 15 commits into
keras-team:master
from
sachinprasadhs:deeplab
Oct 3, 2024
Merged
Changes from all commits
Commits
Show all changes
15 commits
Select commit
Hold shift + click to select a range
8683ead
Add Deeplab and DeepLabV3 with segmentation
sachinprasadhs 78fe47c
Merge branch 'master' into deeplab
sachinprasadhs 5540e20
address comments
sachinprasadhs f48de53
test fix
sachinprasadhs 24a4198
Merge 'upstream/master' into deeplab
sachinprasadhs 0fd1708
update copyright
sachinprasadhs 19b0d44
add preprocessor
sachinprasadhs 1367419
Merge branch 'master' into deeplab
sachinprasadhs a6880fa
fix task test for deeplab
sachinprasadhs ee097a3
Merge branch 'master' into deeplab
sachinprasadhs 553702c
nit
sachinprasadhs baeb683
Merge remote-tracking branch 'upstream/master' into deeplab
sachinprasadhs d5c60a2
modify preprocessor for masks
sachinprasadhs 608c4c6
fix image shape to provide tuple
sachinprasadhs 661390a
nit
sachinprasadhs File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
from keras_hub.src.models.deeplab_v3.deeplab_v3_backbone import ( | ||
DeepLabV3Backbone, | ||
) | ||
from keras_hub.src.models.deeplab_v3.deeplab_v3_presets import backbone_presets | ||
from keras_hub.src.utils.preset_utils import register_presets | ||
|
||
register_presets(backbone_presets, DeepLabV3Backbone) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,196 @@ | ||
import keras | ||
|
||
from keras_hub.src.api_export import keras_hub_export | ||
from keras_hub.src.models.backbone import Backbone | ||
from keras_hub.src.models.deeplab_v3.deeplab_v3_layers import ( | ||
SpatialPyramidPooling, | ||
) | ||
|
||
|
||
@keras_hub_export("keras_hub.models.DeepLabV3Backbone") | ||
class DeepLabV3Backbone(Backbone): | ||
"""DeepLabV3 & DeepLabV3Plus architecture for semantic segmentation. | ||
|
||
This class implements a DeepLabV3 & DeepLabV3Plus architecture as described | ||
in [Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation]( | ||
https://arxiv.org/abs/1802.02611)(ECCV 2018) | ||
and [Rethinking Atrous Convolution for Semantic Image Segmentation]( | ||
https://arxiv.org/abs/1706.05587)(CVPR 2017) | ||
|
||
Args: | ||
image_encoder: `keras.Model`. An instance that is used as a feature | ||
extractor for the Encoder. Should either be a | ||
`keras_hub.models.Backbone` or a `keras.Model` that implements the | ||
`pyramid_outputs` property with keys "P2", "P3" etc as values. | ||
A somewhat sensible backbone to use in many cases is | ||
the `keras_hub.models.ResNetBackbone.from_preset("resnet_v2_50")`. | ||
projection_filters: int. Number of filters in the convolution layer | ||
projecting low-level features from the `image_encoder`. | ||
spatial_pyramid_pooling_key: str. A layer level to extract and perform | ||
`spatial_pyramid_pooling`, one of the key from the `image_encoder` | ||
`pyramid_outputs` property such as "P4", "P5" etc. | ||
upsampling_size: int or tuple of 2 integers. The upsampling factors for | ||
rows and columns of `spatial_pyramid_pooling` layer. | ||
If `low_level_feature_key` is given then `spatial_pyramid_pooling`s | ||
layer resolution should match with the `low_level_feature`s layer | ||
resolution to concatenate both the layers for combined encoder | ||
outputs. | ||
dilation_rates: list. A `list` of integers for parallel dilated conv applied to | ||
`SpatialPyramidPooling`. Usually a | ||
sample choice of rates are `[6, 12, 18]`. | ||
low_level_feature_key: str optional. A layer level to extract the feature | ||
from one of the key from the `image_encoder`s `pyramid_outputs` | ||
property such as "P2", "P3" etc which will be the Decoder block. | ||
Required only when the DeepLabV3Plus architecture needs to be applied. | ||
image_shape: tuple. The input shape without the batch size. | ||
Defaults to `(None, None, 3)`. | ||
|
||
Example: | ||
```python | ||
# Load a trained backbone to extract features from it's `pyramid_outputs`. | ||
image_encoder = keras_hub.models.ResNetBackbone.from_preset("resnet_50_imagenet") | ||
|
||
model = keras_hub.models.DeepLabV3Backbone( | ||
image_encoder=image_encoder, | ||
projection_filters=48, | ||
low_level_feature_key="P2", | ||
spatial_pyramid_pooling_key="P5", | ||
upsampling_size = 8, | ||
dilation_rates = [6, 12, 18] | ||
) | ||
``` | ||
""" | ||
|
||
def __init__( | ||
self, | ||
image_encoder, | ||
spatial_pyramid_pooling_key, | ||
upsampling_size, | ||
dilation_rates, | ||
low_level_feature_key=None, | ||
projection_filters=48, | ||
image_shape=(None, None, 3), | ||
**kwargs, | ||
): | ||
if not isinstance(image_encoder, keras.Model): | ||
raise ValueError( | ||
"Argument `image_encoder` must be a `keras.Model` instance. Received instead " | ||
f"{image_encoder} (of type {type(image_encoder)})." | ||
) | ||
data_format = keras.config.image_data_format() | ||
channel_axis = -1 if data_format == "channels_last" else 1 | ||
|
||
# === Layers === | ||
inputs = keras.layers.Input(image_shape, name="inputs") | ||
|
||
fpn_model = keras.Model( | ||
image_encoder.inputs, image_encoder.pyramid_outputs | ||
) | ||
|
||
fpn_outputs = fpn_model(inputs) | ||
|
||
spatial_pyramid_pooling = SpatialPyramidPooling( | ||
sachinprasadhs marked this conversation as resolved.
Show resolved
Hide resolved
|
||
dilation_rates=dilation_rates | ||
) | ||
spatial_backbone_features = fpn_outputs[spatial_pyramid_pooling_key] | ||
spp_outputs = spatial_pyramid_pooling(spatial_backbone_features) | ||
|
||
encoder_outputs = keras.layers.UpSampling2D( | ||
size=upsampling_size, | ||
interpolation="bilinear", | ||
name="encoder_output_upsampling", | ||
data_format=data_format, | ||
)(spp_outputs) | ||
|
||
if low_level_feature_key: | ||
decoder_feature = fpn_outputs[low_level_feature_key] | ||
low_level_projected_features = apply_low_level_feature_network( | ||
decoder_feature, projection_filters, channel_axis | ||
) | ||
|
||
encoder_outputs = keras.layers.Concatenate( | ||
axis=channel_axis, name="encoder_decoder_concat" | ||
)([encoder_outputs, low_level_projected_features]) | ||
# upsampling to the original image size | ||
upsampling = (2 ** int(spatial_pyramid_pooling_key[-1])) // ( | ||
int(upsampling_size[0]) | ||
if isinstance(upsampling_size, tuple) | ||
else upsampling_size | ||
) | ||
# === Functional Model === | ||
x = keras.layers.Conv2D( | ||
name="segmentation_head_conv", | ||
filters=256, | ||
kernel_size=1, | ||
padding="same", | ||
use_bias=False, | ||
data_format=data_format, | ||
)(encoder_outputs) | ||
x = keras.layers.BatchNormalization( | ||
name="segmentation_head_norm", axis=channel_axis | ||
)(x) | ||
x = keras.layers.ReLU(name="segmentation_head_relu")(x) | ||
x = keras.layers.UpSampling2D( | ||
size=upsampling, | ||
interpolation="bilinear", | ||
data_format=data_format, | ||
name="backbone_output_upsampling", | ||
)(x) | ||
|
||
super().__init__(inputs=inputs, outputs=x, **kwargs) | ||
|
||
# === Config === | ||
self.image_shape = image_shape | ||
self.image_encoder = image_encoder | ||
self.projection_filters = projection_filters | ||
self.upsampling_size = upsampling_size | ||
self.dilation_rates = dilation_rates | ||
self.low_level_feature_key = low_level_feature_key | ||
self.spatial_pyramid_pooling_key = spatial_pyramid_pooling_key | ||
|
||
def get_config(self): | ||
config = super().get_config() | ||
config.update( | ||
{ | ||
"image_encoder": keras.saving.serialize_keras_object( | ||
self.image_encoder | ||
), | ||
"projection_filters": self.projection_filters, | ||
"dilation_rates": self.dilation_rates, | ||
"upsampling_size": self.upsampling_size, | ||
"low_level_feature_key": self.low_level_feature_key, | ||
"spatial_pyramid_pooling_key": self.spatial_pyramid_pooling_key, | ||
"image_shape": self.image_shape, | ||
} | ||
) | ||
return config | ||
|
||
@classmethod | ||
def from_config(cls, config): | ||
if "image_encoder" in config and isinstance( | ||
config["image_encoder"], dict | ||
): | ||
config["image_encoder"] = keras.layers.deserialize( | ||
config["image_encoder"] | ||
) | ||
return super().from_config(config) | ||
|
||
|
||
def apply_low_level_feature_network( | ||
input_tensor, projection_filters, channel_axis | ||
): | ||
data_format = keras.config.image_data_format() | ||
x = keras.layers.Conv2D( | ||
name="decoder_conv", | ||
filters=projection_filters, | ||
kernel_size=1, | ||
padding="same", | ||
use_bias=False, | ||
data_format=data_format, | ||
)(input_tensor) | ||
|
||
x = keras.layers.BatchNormalization(name="decoder_norm", axis=channel_axis)( | ||
x | ||
) | ||
x = keras.layers.ReLU(name="decoder_relu")(x) | ||
return x |
73 changes: 73 additions & 0 deletions
73
keras_hub/src/models/deeplab_v3/deeplab_v3_backbone_test.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
import keras | ||
import numpy as np | ||
import pytest | ||
|
||
from keras_hub.src.models.deeplab_v3.deeplab_v3_backbone import ( | ||
DeepLabV3Backbone, | ||
) | ||
from keras_hub.src.models.deeplab_v3.deeplab_v3_layers import ( | ||
SpatialPyramidPooling, | ||
) | ||
from keras_hub.src.models.resnet.resnet_backbone import ResNetBackbone | ||
from keras_hub.src.tests.test_case import TestCase | ||
|
||
|
||
class DeepLabV3Test(TestCase): | ||
def setUp(self): | ||
self.resnet_kwargs = { | ||
"input_conv_filters": [64], | ||
"input_conv_kernel_sizes": [7], | ||
"stackwise_num_filters": [64, 64, 64], | ||
"stackwise_num_blocks": [2, 2, 2], | ||
"stackwise_num_strides": [1, 2, 2], | ||
"block_type": "basic_block", | ||
"use_pre_activation": False, | ||
} | ||
self.image_encoder = ResNetBackbone(**self.resnet_kwargs) | ||
self.init_kwargs = { | ||
"image_encoder": self.image_encoder, | ||
"low_level_feature_key": "P2", | ||
"spatial_pyramid_pooling_key": "P4", | ||
"dilation_rates": [6, 12, 18], | ||
"upsampling_size": 4, | ||
"image_shape": (96, 96, 3), | ||
} | ||
self.input_data = np.ones((2, 96, 96, 3), dtype="float32") | ||
|
||
def test_segmentation_basics(self): | ||
self.run_vision_backbone_test( | ||
cls=DeepLabV3Backbone, | ||
init_kwargs=self.init_kwargs, | ||
input_data=self.input_data, | ||
expected_output_shape=(2, 96, 96, 256), | ||
run_mixed_precision_check=False, | ||
run_quantization_check=False, | ||
run_data_format_check=False, | ||
) | ||
|
||
@pytest.mark.large | ||
def test_saved_model(self): | ||
self.run_model_saving_test( | ||
cls=DeepLabV3Backbone, | ||
init_kwargs=self.init_kwargs, | ||
input_data=self.input_data, | ||
) | ||
|
||
|
||
class SpatialPyramidPoolingTest(TestCase): | ||
def test_layer_behaviors(self): | ||
self.run_layer_test( | ||
cls=SpatialPyramidPooling, | ||
init_kwargs={ | ||
"dilation_rates": [6, 12, 18], | ||
"activation": "relu", | ||
"num_channels": 256, | ||
"dropout": 0.1, | ||
}, | ||
input_data=keras.random.uniform(shape=(1, 4, 4, 6)), | ||
expected_output_shape=(1, 4, 4, 256), | ||
expected_num_trainable_weights=18, | ||
expected_num_non_trainable_variables=13, | ||
expected_num_non_trainable_weights=12, | ||
run_precision_checks=False, | ||
) |
10 changes: 10 additions & 0 deletions
10
keras_hub/src/models/deeplab_v3/deeplab_v3_image_converter.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
from keras_hub.src.api_export import keras_hub_export | ||
from keras_hub.src.layers.preprocessing.image_converter import ImageConverter | ||
from keras_hub.src.models.deeplab_v3.deeplab_v3_backbone import ( | ||
DeepLabV3Backbone, | ||
) | ||
|
||
|
||
@keras_hub_export("keras_hub.layers.DeepLabV3ImageConverter") | ||
class DeepLabV3ImageConverter(ImageConverter): | ||
backbone_cls = DeepLabV3Backbone |
16 changes: 16 additions & 0 deletions
16
keras_hub/src/models/deeplab_v3/deeplab_v3_image_segmeter_preprocessor.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
from keras_hub.src.api_export import keras_hub_export | ||
from keras_hub.src.models.deeplab_v3.deeplab_v3_backbone import ( | ||
DeepLabV3Backbone, | ||
) | ||
from keras_hub.src.models.deeplab_v3.deeplab_v3_image_converter import ( | ||
DeepLabV3ImageConverter, | ||
) | ||
from keras_hub.src.models.image_segmenter_preprocessor import ( | ||
ImageSegmenterPreprocessor, | ||
) | ||
|
||
|
||
@keras_hub_export("keras_hub.models.DeepLabV3ImageSegmenterPreprocessor") | ||
class DeepLabV3ImageSegmenterPreprocessor(ImageSegmenterPreprocessor): | ||
backbone_cls = DeepLabV3Backbone | ||
image_converter_cls = DeepLabV3ImageConverter |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.