-
Notifications
You must be signed in to change notification settings - Fork 285
Add VGG16 and VGG19 backbone #1737
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
divyashreepathihalli
merged 25 commits into
keras-team:keras-hub
from
divyashreepathihalli:VGG16
Aug 8, 2024
Merged
Changes from all commits
Commits
Show all changes
25 commits
Select commit
Hold shift + click to select a range
246c242
Agg Vgg16 backbone
divyashreepathihalli 25fddb0
update names
divyashreepathihalli 0d3414f
update tests
divyashreepathihalli a0c1a72
update test
divyashreepathihalli 3d57f73
add image classifier
divyashreepathihalli c23d573
incorporate review comments
divyashreepathihalli 0e73b6f
Update test case
divyashreepathihalli fac566f
update backbone test
divyashreepathihalli eef4405
add image classifier
divyashreepathihalli 0c481ef
classifier cleanup
divyashreepathihalli 5206dc9
code reformat
divyashreepathihalli 1bcd5b2
add vgg16 image classifier
divyashreepathihalli a8b4bf2
make vgg generic
divyashreepathihalli 41a8733
update doc string
divyashreepathihalli 40ad2ed
update docstring
divyashreepathihalli b1a6dfd
add classifier test
divyashreepathihalli 443af98
update tests
divyashreepathihalli eb818d1
update docstring
divyashreepathihalli f8c92c2
address review comments
divyashreepathihalli 7dae882
code reformat
divyashreepathihalli cbf5ed7
update the configs
divyashreepathihalli 483e7bd
address review comments
divyashreepathihalli d8a6745
fix task saved model test
divyashreepathihalli 5f223e5
update init
divyashreepathihalli 901f7ae
code reformatted
divyashreepathihalli File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
# Copyright 2023 The KerasNLP Authors | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
import keras | ||
|
||
from keras_nlp.src.api_export import keras_nlp_export | ||
from keras_nlp.src.models.task import Task | ||
|
||
|
||
@keras_nlp_export("keras_nlp.models.ImageClassifier") | ||
class ImageClassifier(Task): | ||
"""Base class for all image classification tasks. | ||
|
||
`ImageClassifier` tasks wrap a `keras_nlp.models.Backbone` and | ||
a `keras_nlp.models.Preprocessor` to create a model that can be used for | ||
image classification. `ImageClassifier` tasks take an additional | ||
`num_classes` argument, controlling the number of predicted output classes. | ||
|
||
To fine-tune with `fit()`, pass a dataset containing tuples of `(x, y)` | ||
labels where `x` is a string and `y` is a integer from `[0, num_classes)`. | ||
|
||
All `ImageClassifier` tasks include a `from_preset()` constructor which can be | ||
used to load a pre-trained config and weights. | ||
""" | ||
|
||
def __init__(self, *args, **kwargs): | ||
super().__init__(*args, **kwargs) | ||
# Default compilation. | ||
self.compile() | ||
|
||
def compile( | ||
self, | ||
optimizer="auto", | ||
loss="auto", | ||
*, | ||
metrics="auto", | ||
**kwargs, | ||
): | ||
"""Configures the `ImageClassifier` task for training. | ||
|
||
The `ImageClassifier` task extends the default compilation signature of | ||
`keras.Model.compile` with defaults for `optimizer`, `loss`, and | ||
`metrics`. To override these defaults, pass any value | ||
to these arguments during compilation. | ||
|
||
Args: | ||
optimizer: `"auto"`, an optimizer name, or a `keras.Optimizer` | ||
instance. Defaults to `"auto"`, which uses the default optimizer | ||
for the given model and task. See `keras.Model.compile` and | ||
`keras.optimizers` for more info on possible `optimizer` values. | ||
loss: `"auto"`, a loss name, or a `keras.losses.Loss` instance. | ||
Defaults to `"auto"`, where a | ||
`keras.losses.SparseCategoricalCrossentropy` loss will be | ||
applied for the classification task. See | ||
`keras.Model.compile` and `keras.losses` for more info on | ||
possible `loss` values. | ||
metrics: `"auto"`, or a list of metrics to be evaluated by | ||
the model during training and testing. Defaults to `"auto"`, | ||
where a `keras.metrics.SparseCategoricalAccuracy` will be | ||
applied to track the accuracy of the model during training. | ||
See `keras.Model.compile` and `keras.metrics` for | ||
more info on possible `metrics` values. | ||
**kwargs: See `keras.Model.compile` for a full list of arguments | ||
supported by the compile method. | ||
""" | ||
if optimizer == "auto": | ||
optimizer = keras.optimizers.Adam(5e-5) | ||
if loss == "auto": | ||
activation = getattr(self, "activation", None) | ||
activation = keras.activations.get(activation) | ||
from_logits = activation != keras.activations.softmax | ||
loss = keras.losses.SparseCategoricalCrossentropy(from_logits) | ||
if metrics == "auto": | ||
metrics = [keras.metrics.SparseCategoricalAccuracy()] | ||
super().compile( | ||
optimizer=optimizer, | ||
loss=loss, | ||
metrics=metrics, | ||
**kwargs, | ||
) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
# Copyright 2024 The KerasNLP Authors | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,159 @@ | ||
# Copyright 2023 The KerasNLP Authors | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
import keras | ||
from keras import layers | ||
|
||
from keras_nlp.src.api_export import keras_nlp_export | ||
from keras_nlp.src.models.backbone import Backbone | ||
|
||
|
||
@keras_nlp_export("keras_nlp.models.VGGBackbone") | ||
class VGGBackbone(Backbone): | ||
""" | ||
This class represents Keras Backbone of VGG model. | ||
|
||
This class implements a VGG backbone as described in [Very Deep | ||
Convolutional Networks for Large-Scale Image Recognition]( | ||
https://arxiv.org/abs/1409.1556)(ICLR 2015). | ||
|
||
Args: | ||
stackwise_num_repeats: list of ints, number of repeated convolutional | ||
blocks per VGG block. For VGG16 this is [2, 2, 3, 3, 3] and for | ||
VGG19 this is [2, 2, 4, 4, 4]. | ||
stackwise_num_filters: list of ints, filter size for convolutional | ||
blocks per VGG block. For both VGG16 and VGG19 this is [ | ||
64, 128, 256, 512, 512]. | ||
include_rescaling: bool, whether to rescale the inputs. If set to | ||
True, inputs will be passed through a `Rescaling(1/255.0)` layer. | ||
input_shape: tuple, optional shape tuple, defaults to (224, 224, 3). | ||
pooling: bool, Optional pooling mode for feature extraction | ||
when `include_top` is `False`. | ||
- `None` means that the output of the model will be | ||
the 4D tensor output of the | ||
last convolutional block. | ||
- `avg` means that global average pooling | ||
will be applied to the output of the | ||
last convolutional block, and thus | ||
the output of the model will be a 2D tensor. | ||
- `max` means that global max pooling will | ||
be applied. | ||
|
||
Examples: | ||
```python | ||
input_data = np.ones((2, 224, 224, 3), dtype="float32") | ||
|
||
# Pretrained VGG backbone. | ||
model = keras_nlp.models.VGGBackbone.from_preset("vgg16") | ||
model(input_data) | ||
|
||
# Randomly initialized VGG backbone with a custom config. | ||
model = keras_nlp.models.VGGBackbone( | ||
stackwise_num_repeats = [2, 2, 3, 3, 3], | ||
stackwise_num_filters = [64, 128, 256, 512, 512], | ||
input_shape = (224, 224, 3), | ||
divyashreepathihalli marked this conversation as resolved.
Show resolved
Hide resolved
|
||
include_rescaling = False, | ||
pooling = "avg", | ||
) | ||
model(input_data) | ||
``` | ||
""" | ||
|
||
def __init__( | ||
self, | ||
stackwise_num_repeats, | ||
stackwise_num_filters, | ||
include_rescaling, | ||
input_image_shape=(224, 224, 3), | ||
pooling="avg", | ||
**kwargs, | ||
): | ||
|
||
# === Functional Model === | ||
img_input = keras.layers.Input(shape=input_image_shape) | ||
x = img_input | ||
|
||
if include_rescaling: | ||
x = layers.Rescaling(scale=1 / 255.0)(x) | ||
for stack_index in range(len(stackwise_num_repeats) - 1): | ||
x = apply_vgg_block( | ||
x=x, | ||
num_layers=stackwise_num_repeats[stack_index], | ||
filters=stackwise_num_filters[stack_index], | ||
kernel_size=(3, 3), | ||
activation="relu", | ||
padding="same", | ||
max_pool=True, | ||
name=f"block{stack_index + 1}", | ||
) | ||
if pooling == "avg": | ||
x = layers.GlobalAveragePooling2D()(x) | ||
elif pooling == "max": | ||
x = layers.GlobalMaxPooling2D()(x) | ||
|
||
super().__init__(inputs=img_input, outputs=x, **kwargs) | ||
|
||
# === Config === | ||
self.stackwise_num_repeats = stackwise_num_repeats | ||
self.stackwise_num_filters = stackwise_num_filters | ||
self.include_rescaling = include_rescaling | ||
self.input_image_shape = input_image_shape | ||
self.pooling = pooling | ||
|
||
def get_config(self): | ||
return { | ||
"stackwise_num_repeats": self.stackwise_num_repeats, | ||
"stackwise_num_filters": self.stackwise_num_filters, | ||
"include_rescaling": self.include_rescaling, | ||
"input_image_shape": self.input_image_shape, | ||
"pooling": self.pooling, | ||
} | ||
|
||
|
||
def apply_vgg_block( | ||
divyashreepathihalli marked this conversation as resolved.
Show resolved
Hide resolved
|
||
x, | ||
num_layers, | ||
filters, | ||
kernel_size, | ||
activation, | ||
padding, | ||
max_pool, | ||
name, | ||
): | ||
""" | ||
Applies VGG block | ||
Args: | ||
x: Tensor, input tensor to pass through network | ||
num_layers: int, number of CNN layers in the block | ||
filters: int, filter size of each CNN layer in block | ||
kernel_size: int (or) tuple, kernel size for CNN layer in block | ||
activation: str (or) callable, activation function for each CNN layer in | ||
block | ||
padding: str (or) callable, padding function for each CNN layer in block | ||
max_pool: bool, whether to add MaxPooling2D layer at end of block | ||
name: str, name of the block | ||
|
||
Returns: | ||
keras.KerasTensor | ||
""" | ||
for num in range(1, num_layers + 1): | ||
x = layers.Conv2D( | ||
filters, | ||
kernel_size, | ||
activation=activation, | ||
padding=padding, | ||
name=f"{name}_conv{num}", | ||
)(x) | ||
if max_pool: | ||
x = layers.MaxPooling2D((2, 2), (2, 2), name=f"{name}_pool")(x) | ||
return x |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
# Copyright 2023 The KerasNLP Authors | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import numpy as np | ||
import pytest | ||
|
||
from keras_nlp.src.models.vgg.vgg_backbone import VGGBackbone | ||
from keras_nlp.src.tests.test_case import TestCase | ||
|
||
|
||
class VGGBackboneTest(TestCase): | ||
def setUp(self): | ||
self.init_kwargs = { | ||
"stackwise_num_repeats": [2, 3, 3], | ||
"stackwise_num_filters": [8, 64, 64], | ||
"input_image_shape": (16, 16, 3), | ||
"include_rescaling": False, | ||
"pooling": "avg", | ||
} | ||
self.input_data = np.ones((2, 16, 16, 3), dtype="float32") | ||
|
||
def test_backbone_basics(self): | ||
self.run_backbone_test( | ||
cls=VGGBackbone, | ||
init_kwargs=self.init_kwargs, | ||
input_data=self.input_data, | ||
expected_output_shape=(2, 64), | ||
run_mixed_precision_check=False, | ||
) | ||
|
||
@pytest.mark.large | ||
def test_saved_model(self): | ||
self.run_model_saving_test( | ||
cls=VGGBackbone, | ||
init_kwargs=self.init_kwargs, | ||
input_data=self.input_data, | ||
) |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.