@@ -72,15 +72,43 @@ class Functional(training_lib.Model):
72
72
73
73
Example:
74
74
75
- ```
75
+ ```python
76
76
inputs = keras.Input(shape=(10,))
77
77
x = keras.layers.Dense(1)(inputs)
78
78
outputs = tf.nn.relu(x)
79
79
model = keras.Model(inputs, outputs)
80
80
```
81
81
82
+ A new `Functional` model can also be created by using the
83
+ intermediate tensors. This enables you to quickly extract sub-components
84
+ of the model.
85
+
86
+ Example:
87
+
88
+ ```python
89
+ inputs = keras.Input(shape=(None, None, 3))
90
+ processed = keras.layers.RandomCrop(width=32, height=32)(inputs)
91
+ conv = keras.layers.Conv2D(filters=2, kernel_size=3)(processed)
92
+ pooling = keras.layers.GlobalAveragePooling2D()(conv)
93
+ feature = keras.layers.Dense(10)(pooling)
94
+
95
+ full_model = keras.Model(inputs, feature)
96
+ backbone = keras.Model(processed, conv)
97
+ activations = keras.Model(conv, feature)
98
+ ```
99
+
100
+ Note that the `backbone` and `activations` models are not
101
+ created with `keras.Input` objects, but with the tensors that are originated
102
+ from `keras.Inputs` objects. Under the hood, the layers and weights will
103
+ be shared across these models, so that user can train the `full_model`, and
104
+ use `backbone` or `activations` to do feature extraction.
105
+ The inputs and outputs of the model can be nested structures of tensors as
106
+ well, and the created models are standard `Functional` model that support
107
+ all the existing API.
108
+
82
109
Args:
83
- inputs: List of input tensors (must be created via `tf.keras.Input()`).
110
+ inputs: List of input tensors (must be created via `tf.keras.Input()` or
111
+ originated from `tf.keras.Input()`).
84
112
outputs: List of output tensors.
85
113
name: String, optional. Name of the model.
86
114
trainable: Boolean, optional. If the model's variables should be trainable.
0 commit comments