|
| 1 | +import keras |
| 2 | + |
| 3 | +from keras_hub.src.api_export import keras_hub_export |
| 4 | +from keras_hub.src.models.backbone import Backbone |
| 5 | +from keras_hub.src.models.deeplab_v3.deeplab_v3_layers import ( |
| 6 | + SpatialPyramidPooling, |
| 7 | +) |
| 8 | + |
| 9 | + |
@keras_hub_export("keras_hub.models.DeepLabV3Backbone")
class DeepLabV3Backbone(Backbone):
    """DeepLabV3 & DeepLabV3Plus architecture for semantic segmentation.

    This class implements a DeepLabV3 & DeepLabV3Plus architecture as
    described in
    [Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation](
    https://arxiv.org/abs/1802.02611) (ECCV 2018) and
    [Rethinking Atrous Convolution for Semantic Image Segmentation](
    https://arxiv.org/abs/1706.05587) (CVPR 2017).

    Args:
        image_encoder: `keras.Model`. An instance that is used as a feature
            extractor for the Encoder. Should either be a
            `keras_hub.models.Backbone` or a `keras.Model` that implements
            the `pyramid_outputs` property with keys "P2", "P3" etc as
            values. A somewhat sensible backbone to use in many cases is
            `keras_hub.models.ResNetBackbone.from_preset("resnet_v2_50")`.
        projection_filters: int. Number of filters in the convolution layer
            projecting low-level features from the `image_encoder`.
        spatial_pyramid_pooling_key: str. A layer level to extract and
            perform `spatial_pyramid_pooling` on, one of the keys from the
            `image_encoder` `pyramid_outputs` property such as "P4", "P5"
            etc.
        upsampling_size: int or tuple of 2 integers. The upsampling factors
            for rows and columns of the `spatial_pyramid_pooling` layer. If
            `low_level_feature_key` is given, then the
            `spatial_pyramid_pooling` layer's resolution should match the
            `low_level_feature` layer's resolution so both layers can be
            concatenated for the combined encoder outputs.
        dilation_rates: list. A `list` of integers for the parallel dilated
            convolutions applied in `SpatialPyramidPooling`. A usual sample
            choice of rates is `[6, 12, 18]`.
        low_level_feature_key: str, optional. A layer level to extract the
            feature from, one of the keys from the `image_encoder`'s
            `pyramid_outputs` property such as "P2", "P3" etc, which will
            feed the Decoder block. Required only when the DeepLabV3Plus
            architecture needs to be applied.
        image_shape: tuple. The input shape without the batch size.
            Defaults to `(None, None, 3)`.

    Example:
    ```python
    # Load a trained backbone to extract features from its `pyramid_outputs`.
    image_encoder = keras_hub.models.ResNetBackbone.from_preset(
        "resnet_50_imagenet"
    )

    model = keras_hub.models.DeepLabV3Backbone(
        image_encoder=image_encoder,
        projection_filters=48,
        low_level_feature_key="P2",
        spatial_pyramid_pooling_key="P5",
        upsampling_size=8,
        dilation_rates=[6, 12, 18]
    )
    ```
    """

    def __init__(
        self,
        image_encoder,
        spatial_pyramid_pooling_key,
        upsampling_size,
        dilation_rates,
        low_level_feature_key=None,
        projection_filters=48,
        image_shape=(None, None, 3),
        **kwargs,
    ):
        # Fail fast with a descriptive message rather than an opaque
        # AttributeError later when `.inputs`/`.pyramid_outputs` is accessed.
        if not isinstance(image_encoder, keras.Model):
            raise ValueError(
                "Argument `image_encoder` must be a `keras.Model` instance. Received instead "
                f"{image_encoder} (of type {type(image_encoder)})."
            )
        data_format = keras.config.image_data_format()
        channel_axis = -1 if data_format == "channels_last" else 1

        # === Layers ===
        inputs = keras.layers.Input(image_shape, name="inputs")

        # Re-wrap the encoder so that a single call yields the full dict of
        # pyramid feature maps (e.g. {"P2": ..., ..., "P5": ...}).
        fpn_model = keras.Model(
            image_encoder.inputs, image_encoder.pyramid_outputs
        )

        fpn_outputs = fpn_model(inputs)

        # ASPP head: parallel dilated convolutions over the deepest
        # selected pyramid level.
        spatial_pyramid_pooling = SpatialPyramidPooling(
            dilation_rates=dilation_rates
        )
        spatial_backbone_features = fpn_outputs[spatial_pyramid_pooling_key]
        spp_outputs = spatial_pyramid_pooling(spatial_backbone_features)

        encoder_outputs = keras.layers.UpSampling2D(
            size=upsampling_size,
            interpolation="bilinear",
            name="encoder_output_upsampling",
            data_format=data_format,
        )(spp_outputs)

        # DeepLabV3Plus decoder path: project a low-level feature map and
        # fuse it with the upsampled encoder output. Skipped for plain
        # DeepLabV3 (when `low_level_feature_key` is None).
        if low_level_feature_key:
            decoder_feature = fpn_outputs[low_level_feature_key]
            low_level_projected_features = apply_low_level_feature_network(
                decoder_feature, projection_filters, channel_axis
            )

            encoder_outputs = keras.layers.Concatenate(
                axis=channel_axis, name="encoder_decoder_concat"
            )([encoder_outputs, low_level_projected_features])
        # upsampling to the original image size
        # NOTE(review): assumes the pyramid key is "P<k>" with a
        # single-digit k, i.e. a stride of 2**k — would misparse "P10"+.
        # For a tuple `upsampling_size`, only the row factor is used, so
        # non-square upsampling is presumed unsupported — TODO confirm.
        upsampling = (2 ** int(spatial_pyramid_pooling_key[-1])) // (
            int(upsampling_size[0])
            if isinstance(upsampling_size, tuple)
            else upsampling_size
        )
        # === Functional Model ===
        # Segmentation head: 1x1 conv + BN + ReLU, then bilinear upsampling
        # back to the input resolution.
        x = keras.layers.Conv2D(
            name="segmentation_head_conv",
            filters=256,
            kernel_size=1,
            padding="same",
            use_bias=False,
            data_format=data_format,
        )(encoder_outputs)
        x = keras.layers.BatchNormalization(
            name="segmentation_head_norm", axis=channel_axis
        )(x)
        x = keras.layers.ReLU(name="segmentation_head_relu")(x)
        x = keras.layers.UpSampling2D(
            size=upsampling,
            interpolation="bilinear",
            data_format=data_format,
            name="backbone_output_upsampling",
        )(x)

        super().__init__(inputs=inputs, outputs=x, **kwargs)

        # === Config ===
        # Stored verbatim so `get_config()` can round-trip the constructor
        # arguments.
        self.image_shape = image_shape
        self.image_encoder = image_encoder
        self.projection_filters = projection_filters
        self.upsampling_size = upsampling_size
        self.dilation_rates = dilation_rates
        self.low_level_feature_key = low_level_feature_key
        self.spatial_pyramid_pooling_key = spatial_pyramid_pooling_key

    def get_config(self):
        """Return the constructor arguments needed to re-create this model.

        The nested `image_encoder` model is serialized to a dict so the
        whole config is JSON-compatible.
        """
        config = super().get_config()
        config.update(
            {
                "image_encoder": keras.saving.serialize_keras_object(
                    self.image_encoder
                ),
                "projection_filters": self.projection_filters,
                "dilation_rates": self.dilation_rates,
                "upsampling_size": self.upsampling_size,
                "low_level_feature_key": self.low_level_feature_key,
                "spatial_pyramid_pooling_key": self.spatial_pyramid_pooling_key,
                "image_shape": self.image_shape,
            }
        )
        return config

    @classmethod
    def from_config(cls, config):
        """Build the backbone from a config dict.

        Deserializes the nested `image_encoder` (stored as a dict by
        `get_config`) back into a `keras.Model` before delegating to the
        parent constructor.
        """
        if "image_encoder" in config and isinstance(
            config["image_encoder"], dict
        ):
            config["image_encoder"] = keras.layers.deserialize(
                config["image_encoder"]
            )
        return super().from_config(config)
| 177 | + |
| 178 | + |
def apply_low_level_feature_network(
    input_tensor, projection_filters, channel_axis
):
    """Project low-level encoder features for the DeepLabV3Plus decoder.

    Applies a bias-free 1x1 convolution followed by batch normalization
    and a ReLU activation.

    Args:
        input_tensor: Low-level feature map taken from the image encoder's
            `pyramid_outputs`.
        projection_filters: int. Number of filters for the 1x1 projection
            convolution.
        channel_axis: int. Channel axis of the input (`-1` for
            `"channels_last"`, `1` for `"channels_first"`).

    Returns:
        The projected feature tensor.
    """
    data_format = keras.config.image_data_format()

    # Instantiate the three decoder layers up front, then chain them.
    projection = keras.layers.Conv2D(
        name="decoder_conv",
        filters=projection_filters,
        kernel_size=1,
        padding="same",
        use_bias=False,
        data_format=data_format,
    )
    normalization = keras.layers.BatchNormalization(
        name="decoder_norm", axis=channel_axis
    )
    activation = keras.layers.ReLU(name="decoder_relu")

    features = projection(input_tensor)
    features = normalization(features)
    return activation(features)
0 commit comments