Skip to content

Commit 87c5f61

Browse files
committed
Update Keras Functional model docstring for the new feature of submodel slicing.
PiperOrigin-RevId: 399303217
1 parent 89fd0f5 commit 87c5f61

File tree

2 files changed

+57
-2
lines changed

2 files changed

+57
-2
lines changed

keras/engine/functional.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,15 +72,43 @@ class Functional(training_lib.Model):
7272
7373
Example:
7474
75-
```
75+
```python
7676
inputs = keras.Input(shape=(10,))
7777
x = keras.layers.Dense(1)(inputs)
7878
outputs = tf.nn.relu(x)
7979
model = keras.Model(inputs, outputs)
8080
```
8181
82+
A new `Functional` model can also be created by using the
83+
intermediate tensors. This enables you to quickly extract sub-components
84+
of the model.
85+
86+
Example:
87+
88+
```python
89+
inputs = keras.Input(shape=(None, None, 3))
90+
processed = keras.layers.RandomCrop(width=32, height=32)(inputs)
91+
conv = keras.layers.Conv2D(filters=2, kernel_size=3)(processed)
92+
pooling = keras.layers.GlobalAveragePooling2D()(conv)
93+
feature = keras.layers.Dense(10)(pooling)
94+
95+
full_model = keras.Model(inputs, feature)
96+
backbone = keras.Model(processed, conv)
97+
activations = keras.Model(conv, feature)
98+
```
99+
100+
Note that the `backbone` and `activations` models are not
101+
created with `keras.Input` objects, but with the tensors that are originated
102+
from `keras.Inputs` objects. Under the hood, the layers and weights will
103+
be shared across these models, so that user can train the `full_model`, and
104+
use `backbone` or `activations` to do feature extraction.
105+
The inputs and outputs of the model can be nested structures of tensors as
106+
well, and the created models are standard `Functional` model that support
107+
all the existing API.
108+
82109
Args:
83-
inputs: List of input tensors (must be created via `tf.keras.Input()`).
110+
inputs: List of input tensors (must be created via `tf.keras.Input()` or
111+
originated from `tf.keras.Input()`).
84112
outputs: List of output tensors.
85113
name: String, optional. Name of the model.
86114
trainable: Boolean, optional. If the model's variables should be trainable.

keras/engine/training.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,33 @@ class Model(base_layer.Layer, version_utils.ModelVersionSelector):
9090
Note: Only dicts, lists, and tuples of input tensors are supported. Nested
9191
inputs are not supported (e.g. lists of list or dicts of dict).
9292
93+
A new Functional API model can also be created by using the
94+
intermediate tensors. This enables you to quickly extract sub-components
95+
of the model.
96+
97+
Example:
98+
99+
```python
100+
inputs = keras.Input(shape=(None, None, 3))
101+
processed = keras.layers.RandomCrop(width=32, height=32)(inputs)
102+
conv = keras.layers.Conv2D(filters=2, kernel_size=3)(processed)
103+
pooling = keras.layers.GlobalAveragePooling2D()(conv)
104+
feature = keras.layers.Dense(10)(pooling)
105+
106+
full_model = keras.Model(inputs, feature)
107+
backbone = keras.Model(processed, conv)
108+
activations = keras.Model(conv, feature)
109+
```
110+
111+
Note that the `backbone` and `activations` models are not
112+
created with `keras.Input` objects, but with the tensors that are originated
113+
from `keras.Inputs` objects. Under the hood, the layers and weights will
114+
be shared across these models, so that user can train the `full_model`, and
115+
use `backbone` or `activations` to do feature extraction.
116+
The inputs and outputs of the model can be nested structures of tensors as
117+
well, and the created models are standard Functional API models that support
118+
all the existing APIs.
119+
93120
2 - By subclassing the `Model` class: in that case, you should define your
94121
layers in `__init__()` and you should implement the model's forward pass
95122
in `call()`.

0 commit comments

Comments
 (0)