Skip to content

Commit 32fc430

Browse files
committed
Take two of keras-team#1812, simpler classifier head loading
Let's get rid of `load_task_extras`, which is a bad and confusing name. Instead, we will adopt some behavior that is specific to classifiers, but a lot simpler. ```python # Random head. classifier = ImageClassifier.from_preset("resnet50", num_classes=2) # Pretrained head. classifier = ImageClassifier.from_preset("resnet50") # Error, must provide num_classes. classifier = TextClassifier.from_preset("bert_base_en") ```
1 parent d0bb822 commit 32fc430

File tree

6 files changed

+31
-53
lines changed

6 files changed

+31
-53
lines changed

keras_nlp/src/models/preprocessor.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,6 @@ def presets(cls):
126126
def from_preset(
127127
cls,
128128
preset,
129-
load_task_extras=False,
130129
**kwargs,
131130
):
132131
"""Instantiate a `keras_nlp.models.Preprocessor` from a model preset.
@@ -150,9 +149,6 @@ def from_preset(
150149
Args:
151150
preset: string. A built-in preset identifier, a Kaggle Models
152151
handle, a Hugging Face handle, or a path to a local directory.
153-
load_task_extras: bool. If `True`, load the saved task preprocessing
154-
configuration from a `preprocessing.json`. You might use this to
155-
restore the sequence length a model was fine-tuned with.
156152
157153
Examples:
158154
```python
@@ -179,7 +175,7 @@ def from_preset(
179175
# Detect the correct subclass if we need to.
180176
if cls.backbone_cls != backbone_cls:
181177
cls = find_subclass(preset, cls, backbone_cls)
182-
return loader.load_preprocessor(cls, load_task_extras, **kwargs)
178+
return loader.load_preprocessor(cls, **kwargs)
183179

184180
def save_to_preset(self, preset_dir):
185181
"""Save preprocessor to a preset directory.

keras_nlp/src/models/preprocessor_test.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,5 @@ def test_save_to_preset(self, cls, preset_name):
121121
self.assertEqual(set(tokenizer.file_assets), set(expected_assets))
122122

123123
# Check restore.
124-
restored = cls.from_preset(save_dir, load_task_extras=True)
124+
restored = cls.from_preset(save_dir)
125125
self.assertEqual(preprocessor.get_config(), restored.get_config())
126-
restored = cls.from_preset(save_dir, load_task_extras=False)
127-
self.assertNotEqual(preprocessor.get_config(), restored.get_config())

keras_nlp/src/models/task.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,6 @@ def from_preset(
139139
cls,
140140
preset,
141141
load_weights=True,
142-
load_task_extras=False,
143142
**kwargs,
144143
):
145144
"""Instantiate a `keras_nlp.models.Task` from a model preset.
@@ -168,10 +167,6 @@ def from_preset(
168167
load_weights: bool. If `True`, saved weights will be loaded into
169168
the model architecture. If `False`, all weights will be
170169
randomly initialized.
171-
load_task_extras: bool. If `True`, load the saved task configuration
172-
from a `task.json` and any task specific weights from
173-
`task.weights`. You might use this to load a classification
174-
head for a model that has been saved with it.
175170
176171
Examples:
177172
```python
@@ -199,7 +194,12 @@ def from_preset(
199194
# Detect the correct subclass if we need to.
200195
if cls.backbone_cls != backbone_cls:
201196
cls = find_subclass(preset, cls, backbone_cls)
202-
return loader.load_task(cls, load_weights, load_task_extras, **kwargs)
197+
# Specifically for classifiers, we never load task weights if
198+
# num_classes is supplied. We handle this in the task base class because
199+
# it is the same logic for classifiers regardless of modality (text,
200+
# images, audio).
201+
load_task_weights = "num_classes" not in kwargs
202+
return loader.load_task(cls, load_weights, load_task_weights, **kwargs)
203203

204204
def load_task_weights(self, filepath):
205205
"""Load only the tasks specific weights not in the backbone."""

keras_nlp/src/models/task_test.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -138,20 +138,16 @@ def test_save_to_preset(self):
138138
self.assertEqual(BertTextClassifier, check_config_class(task_config))
139139

140140
# Try loading the model from preset directory.
141-
restored_task = TextClassifier.from_preset(
142-
save_dir, load_task_extras=True
143-
)
141+
restored_task = TextClassifier.from_preset(save_dir)
144142

145143
# Check the model output.
146144
data = ["the quick brown fox.", "the slow brown fox."]
147145
ref_out = task.predict(data)
148146
new_out = restored_task.predict(data)
149147
self.assertAllClose(ref_out, new_out)
150148

151-
# Load without head weights.
152-
restored_task = TextClassifier.from_preset(
153-
save_dir, load_task_extras=False, num_classes=2
154-
)
149+
# Load with different (randomly initialized) head weights.
150+
restored_task = TextClassifier.from_preset(save_dir, num_classes=2)
155151
data = ["the quick brown fox.", "the slow brown fox."]
156152
# Full output unequal.
157153
ref_out = task.predict(data)

keras_nlp/src/models/text_classifier.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,12 @@ class TextClassifier(Task):
3232
All `TextClassifier` tasks include a `from_preset()` constructor which can be
3333
used to load a pre-trained config and weights.
3434
35+
Some classification presets (but not all) include classification head
36+
weights in a `task.weights.h5`. For these presets, you can omit passing
37+
`num_classes` to re-create the saved classification head. For all presets, if
38+
`num_classes` is passed as a kwarg to `from_preset()`, the classification
39+
head will be randomly initialized.
40+
3541
Example:
3642
```python
3743
# Load a BERT classifier with pre-trained weights.

keras_nlp/src/utils/preset_utils.py

Lines changed: 14 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -673,7 +673,7 @@ def load_image_converter(self, cls, **kwargs):
673673
"""Load an image converter layer from the preset."""
674674
raise NotImplementedError
675675

676-
def load_task(self, cls, load_weights, load_task_extras, **kwargs):
676+
def load_task(self, cls, load_weights, load_task_weights, **kwargs):
677677
"""Load a task model from the preset.
678678
679679
By default, we create a task from a backbone and preprocessor with
@@ -689,11 +689,10 @@ def load_task(self, cls, load_weights, load_task_extras, **kwargs):
689689
if "preprocessor" not in kwargs:
690690
kwargs["preprocessor"] = self.load_preprocessor(
691691
cls.preprocessor_cls,
692-
load_task_extras=load_task_extras,
693692
)
694693
return cls(**kwargs)
695694

696-
def load_preprocessor(self, cls, load_task_extras, **kwargs):
695+
def load_preprocessor(self, cls, **kwargs):
697696
"""Load a prepocessor layer from the preset.
698697
699698
By default, we create a preprocessor from a tokenizer with default
@@ -738,33 +737,25 @@ def load_image_converter(self, cls, **kwargs):
738737
converter_config = load_json(self.preset, IMAGE_CONVERTER_CONFIG_FILE)
739738
return load_serialized_object(converter_config, **kwargs)
740739

741-
def load_task(self, cls, load_weights, load_task_extras, **kwargs):
740+
def load_task(self, cls, load_weights, load_task_weights, **kwargs):
742741
# If there is no `task.json` or it's for the wrong class delegate to the
743742
# super class loader.
744-
if not load_task_extras:
745-
return super().load_task(
746-
cls, load_weights, load_task_extras, **kwargs
747-
)
748743
if not check_file_exists(self.preset, TASK_CONFIG_FILE):
749-
raise ValueError(
750-
"Saved preset has no `task.json`, cannot load the task config "
751-
"from a file. Call `from_preset()` with "
752-
"`load_task_extras=False` to load the task from a backbone "
753-
"with library defaults."
744+
return super().load_task(
745+
cls, load_weights, load_task_weights, **kwargs
754746
)
755747
task_config = load_json(self.preset, TASK_CONFIG_FILE)
756748
if not issubclass(check_config_class(task_config), cls):
757-
raise ValueError(
758-
f"Saved `task.json`does not match calling cls {cls}. Call "
759-
"`from_preset()` with `load_task_extras=False` to load the "
760-
"task from a backbone with library defaults."
749+
return super().load_task(
750+
cls, load_weights, load_task_weights, **kwargs
761751
)
762752
# We found a `task.json` with a complete config for our class.
763753
task = load_serialized_object(task_config, **kwargs)
764754
if task.preprocessor is not None:
765755
task.preprocessor.tokenizer.load_preset_assets(self.preset)
766756
if load_weights:
767-
if check_file_exists(self.preset, TASK_WEIGHTS_FILE):
757+
has_task_weights = check_file_exists(self.preset, TASK_WEIGHTS_FILE)
758+
if has_task_weights and load_task_weights:
768759
jax_memory_cleanup(task)
769760
task_weights = get_file(self.preset, TASK_WEIGHTS_FILE)
770761
task.load_task_weights(task_weights)
@@ -774,23 +765,14 @@ def load_task(self, cls, load_weights, load_task_extras, **kwargs):
774765
task.backbone.load_weights(backbone_weights)
775766
return task
776767

777-
def load_preprocessor(self, cls, load_task_extras, **kwargs):
778-
if not load_task_extras:
779-
return super().load_preprocessor(cls, load_task_extras, **kwargs)
768+
def load_preprocessor(self, cls, **kwargs):
769+
# If there is no `preprocessing.json` or it's for the wrong class,
770+
# delegate to the super class loader.
780771
if not check_file_exists(self.preset, PREPROCESSOR_CONFIG_FILE):
781-
raise ValueError(
782-
"Saved preset has no `preprocessor.json`, cannot load the task "
783-
"preprocessing config from a file. Call `from_preset()` with "
784-
"`load_task_extras=False` to load the preprocessor with "
785-
"library defaults."
786-
)
772+
return super().load_preprocessor(cls, **kwargs)
787773
preprocessor_json = load_json(self.preset, PREPROCESSOR_CONFIG_FILE)
788774
if not issubclass(check_config_class(preprocessor_json), cls):
789-
raise ValueError(
790-
f"Saved `preprocessor.json`does not match calling cls {cls}. "
791-
"Call `from_preset()` with `load_task_extras=False` to "
792-
"load the the preprocessor with library defaults."
793-
)
775+
return super().load_preprocessor(cls, **kwargs)
794776
# We found a `preprocessing.json` with a complete config for our class.
795777
preprocessor = load_serialized_object(preprocessor_json, **kwargs)
796778
preprocessor.tokenizer.load_preset_assets(self.preset)

0 commit comments

Comments
 (0)