Only load a full task config when load_task_extras is passed #1812

Merged 3 commits on Sep 10, 2024
Changes from 1 commit

6 changes: 5 additions & 1 deletion keras_nlp/src/models/preprocessor.py
@@ -89,6 +89,7 @@ def presets(cls):
def from_preset(
cls,
preset,
load_task=False,
**kwargs,
):
"""Instantiate a `keras_nlp.models.Preprocessor` from a model preset.
@@ -112,6 +113,9 @@ def from_preset(
Args:
preset: string. A built in preset identifier, a Kaggle Models
handle, a Hugging Face handle, or a path to a local directory.
load_task: bool. If `True`, load the saved task preprocessing
configuration from a `preprocessor.json`. You might use this to
restore the sequence length a model was fine-tuned with.

Examples:
```python
@@ -138,7 +142,7 @@ def from_preset(
# Detect the correct subclass if we need to.
if cls.backbone_cls != backbone_cls:
cls = find_subclass(preset, cls, backbone_cls)
return loader.load_preprocessor(cls, **kwargs)
return loader.load_preprocessor(cls, load_task, **kwargs)

def save_to_preset(self, preset_dir):
"""Save preprocessor to a preset directory.
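
For context on the new argument documented above, here is a minimal usage sketch. It follows this commit's flag name (`load_task`), mirrors the updated test below, and uses a hypothetical local directory `./my_preset` with an illustrative sequence length.

```python
import keras_nlp

# Save a customized preprocessor to a hypothetical local preset directory.
preprocessor = keras_nlp.models.BertTextClassifierPreprocessor.from_preset(
    "bert_tiny_en_uncased", sequence_length=100
)
preprocessor.save_to_preset("./my_preset")
# Also save a backbone config so the preset directory is valid, as the
# updated test does.
backbone = keras_nlp.models.BertBackbone.from_preset(
    "bert_tiny_en_uncased", load_weights=False
)
backbone.save_to_preset("./my_preset")

# Opt in to the saved preprocessing config, e.g. to restore the custom
# sequence length (100) rather than the library default.
restored = keras_nlp.models.BertTextClassifierPreprocessor.from_preset(
    "./my_preset", load_task=True
)

# Default (load_task=False): build from the tokenizer with library
# defaults, ignoring the saved preprocessor config.
fresh = keras_nlp.models.BertTextClassifierPreprocessor.from_preset(
    "./my_preset"
)
```
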
53 changes: 25 additions & 28 deletions keras_nlp/src/models/preprocessor_test.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pathlib

import pytest
from absl.testing import parameterized
@@ -31,10 +32,11 @@
RobertaTextClassifierPreprocessor,
)
from keras_nlp.src.tests.test_case import TestCase
from keras_nlp.src.utils.preset_utils import PREPROCESSOR_CONFIG_FILE
from keras_nlp.src.tokenizers.byte_pair_tokenizer import BytePairTokenizer
from keras_nlp.src.tokenizers.sentence_piece_tokenizer import (
SentencePieceTokenizer,
)
from keras_nlp.src.utils.preset_utils import TOKENIZER_ASSET_DIR
from keras_nlp.src.utils.preset_utils import check_config_class
from keras_nlp.src.utils.preset_utils import load_json


class TestPreprocessor(TestCase):
@@ -80,45 +82,40 @@ def test_from_preset_errors(self):
# TODO: Add more tests when we add a model that has `preprocessor.json`.

@parameterized.parameters(
(
AlbertTextClassifierPreprocessor,
"albert_base_en_uncased",
"sentencepiece",
),
(RobertaTextClassifierPreprocessor, "roberta_base_en", "bytepair"),
(BertTextClassifierPreprocessor, "bert_tiny_en_uncased", "wordpiece"),
(AlbertTextClassifierPreprocessor, "albert_base_en_uncased"),
(RobertaTextClassifierPreprocessor, "roberta_base_en"),
(BertTextClassifierPreprocessor, "bert_tiny_en_uncased"),
)
@pytest.mark.large
def test_save_to_preset(self, cls, preset_name, tokenizer_type):
def test_save_to_preset(self, cls, preset_name):
save_dir = self.get_temp_dir()
preprocessor = cls.from_preset(preset_name)
preprocessor = cls.from_preset(preset_name, sequence_length=100)
tokenizer = preprocessor.tokenizer
preprocessor.save_to_preset(save_dir)
# Save a backbone so the preset is valid.
backbone = cls.backbone_cls.from_preset(preset_name, load_weights=False)
backbone.save_to_preset(save_dir)

if tokenizer_type == "bytepair":
if isinstance(tokenizer, BytePairTokenizer):
vocab_filename = "vocabulary.json"
expected_assets = [
"vocabulary.json",
"merges.txt",
]
elif tokenizer_type == "sentencepiece":
expected_assets = ["vocabulary.json", "merges.txt"]
elif isinstance(tokenizer, SentencePieceTokenizer):
vocab_filename = "vocabulary.spm"
expected_assets = ["vocabulary.spm"]
else:
vocab_filename = "vocabulary.txt"
expected_assets = ["vocabulary.txt"]

# Check existence of vocab file.
vocab_path = os.path.join(
save_dir, os.path.join(TOKENIZER_ASSET_DIR, vocab_filename)
)
path = pathlib.Path(save_dir)
vocab_path = path / TOKENIZER_ASSET_DIR / vocab_filename
self.assertTrue(os.path.exists(vocab_path))

# Check assets.
self.assertEqual(
set(preprocessor.tokenizer.file_assets),
set(expected_assets),
)
self.assertEqual(set(tokenizer.file_assets), set(expected_assets))

# Check config class.
preprocessor_config = load_json(save_dir, PREPROCESSOR_CONFIG_FILE)
self.assertEqual(cls, check_config_class(preprocessor_config))
# Check restore.
restored = cls.from_preset(save_dir, load_task=True)
self.assertEqual(preprocessor.get_config(), restored.get_config())
restored = cls.from_preset(save_dir, load_task=False)
self.assertNotEqual(preprocessor.get_config(), restored.get_config())
16 changes: 11 additions & 5 deletions keras_nlp/src/models/task.py
@@ -146,6 +146,7 @@ def from_preset(
cls,
preset,
load_weights=True,
load_task=False,
**kwargs,
):
"""Instantiate a `keras_nlp.models.Task` from a model preset.
@@ -171,9 +172,13 @@
Args:
preset: string. A built in preset identifier, a Kaggle Models
handle, a Hugging Face handle, or a path to a local directory.
load_weights: bool. If `True`, the weights will be loaded into the
model architecture. If `False`, the weights will be randomly
initialized.
load_weights: bool. If `True`, the backbone weights will be loaded
into the model architecture. If `False`, the weights will be
randomly initialized.
load_task: bool. If `True`, load the saved task configuration
from a `task.json` and any task-specific weights from
`task.weights`. You might use this to load a classification
head for a model that has been saved with it.

Examples:
```python
Expand Down Expand Up @@ -201,13 +206,14 @@ def from_preset(
# Detect the correct subclass if we need to.
if cls.backbone_cls != backbone_cls:
cls = find_subclass(preset, cls, backbone_cls)
return loader.load_task(cls, load_weights, **kwargs)
return loader.load_task(cls, load_weights, load_task, **kwargs)

def load_task_weights(self, filepath):
"""Load only the tasks specific weights not in the backbone."""
if not str(filepath).endswith(".weights.h5"):
raise ValueError(
"The filename must end in `.weights.h5`. Received: filepath={filepath}"
"The filename must end in `.weights.h5`. "
f"Received: filepath={filepath}"
)
backbone_layer_ids = set(id(w) for w in self.backbone._flatten_layers())
keras.saving.load_weights(
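
A companion sketch for the task-level flag documented above, again using this commit's `load_task` name and a hypothetical local directory `./my_preset`.

```python
import keras_nlp

# Save a classifier, including its task head, to a local preset directory.
task = keras_nlp.models.BertTextClassifier.from_preset(
    "bert_tiny_en_uncased", num_classes=2
)
task.save_to_preset("./my_preset")

# load_task=True restores the saved task config (`task.json`) and any
# task-specific weights, e.g. the classification head.
restored = keras_nlp.models.BertTextClassifier.from_preset(
    "./my_preset", load_task=True
)

# load_task=False (the default) builds the task from the backbone with
# library defaults, so task arguments such as `num_classes` must be passed
# again and the head weights start freshly initialized.
fresh_head = keras_nlp.models.BertTextClassifier.from_preset(
    "./my_preset", load_task=False, num_classes=2
)
```
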
47 changes: 28 additions & 19 deletions keras_nlp/src/models/task_test.py
@@ -13,6 +13,7 @@
# limitations under the License.

import os
import pathlib

import keras
import pytest
@@ -109,23 +110,16 @@ def test_summary_without_preprocessor(self):
@pytest.mark.large
def test_save_to_preset(self):
save_dir = self.get_temp_dir()
model = TextClassifier.from_preset(
"bert_tiny_en_uncased", num_classes=2
)
model.save_to_preset(save_dir)
task = TextClassifier.from_preset("bert_tiny_en_uncased", num_classes=2)
task.save_to_preset(save_dir)

# Check existence of files.
self.assertTrue(os.path.exists(os.path.join(save_dir, CONFIG_FILE)))
self.assertTrue(
os.path.exists(os.path.join(save_dir, MODEL_WEIGHTS_FILE))
)
self.assertTrue(os.path.exists(os.path.join(save_dir, METADATA_FILE)))
self.assertTrue(
os.path.exists(os.path.join(save_dir, TASK_CONFIG_FILE))
)
self.assertTrue(
os.path.exists(os.path.join(save_dir, TASK_WEIGHTS_FILE))
)
path = pathlib.Path(save_dir)
self.assertTrue(os.path.exists(path / CONFIG_FILE))
self.assertTrue(os.path.exists(path / MODEL_WEIGHTS_FILE))
self.assertTrue(os.path.exists(path / METADATA_FILE))
self.assertTrue(os.path.exists(path / TASK_CONFIG_FILE))
self.assertTrue(os.path.exists(path / TASK_WEIGHTS_FILE))

# Check the task config (`task.json`).
task_config = load_json(save_dir, TASK_CONFIG_FILE)
@@ -138,13 +132,28 @@ def test_save_to_preset(self):
self.assertEqual(BertTextClassifier, check_config_class(task_config))

# Try loading the model from preset directory.
restored_model = TextClassifier.from_preset(save_dir)
restored_task = TextClassifier.from_preset(save_dir, load_task=True)

# Check the model output.
data = ["the quick brown fox.", "the slow brown fox."]
ref_out = model.predict(data)
new_out = restored_model.predict(data)
self.assertAllEqual(ref_out, new_out)
ref_out = task.predict(data)
new_out = restored_task.predict(data)
self.assertAllClose(ref_out, new_out)

# Load without head weights.
restored_task = TextClassifier.from_preset(
save_dir, load_task=False, num_classes=2
)
data = ["the quick brown fox.", "the slow brown fox."]
# Full output unequal.
ref_out = task.predict(data)
new_out = restored_task.predict(data)
self.assertNotAllClose(ref_out, new_out)
# Backbone output equal.
data = task.preprocessor(data)
ref_out = task.backbone.predict(data)
new_out = restored_task.backbone.predict(data)
self.assertAllClose(ref_out, new_out)

@pytest.mark.large
def test_none_preprocessor(self):
44 changes: 32 additions & 12 deletions keras_nlp/src/utils/preset_utils.py
@@ -656,7 +656,7 @@ def load_tokenizer(self, cls, **kwargs):
"""Load a tokenizer layer from the preset."""
raise NotImplementedError

def load_task(self, cls, load_weights, **kwargs):
def load_task(self, cls, load_weights, load_task, **kwargs):
"""Load a task model from the preset.

By default, we create a task from a backbone and preprocessor with
@@ -671,11 +671,12 @@ def load_task(self, cls, load_weights, **kwargs):
)
if "preprocessor" not in kwargs:
kwargs["preprocessor"] = self.load_preprocessor(
cls.preprocessor_cls
cls.preprocessor_cls,
load_task=load_task,
)
return cls(**kwargs)

def load_preprocessor(self, cls, **kwargs):
def load_preprocessor(self, cls, load_task, **kwargs):
"""Load a prepocessor layer from the preset.

By default, we create a preprocessor from a tokenizer with default
@@ -704,35 +705,54 @@ def load_tokenizer(self, cls, **kwargs):
tokenizer.load_preset_assets(self.preset)
return tokenizer

def load_task(self, cls, load_weights, **kwargs):
def load_task(self, cls, load_weights, load_task, **kwargs):
Member commented:

Isn't it a little confusing that the function is called `load_task` and the new flag is also called `load_task`? If the `load_task` flag is false, it only loads the backbone, so it is not loading the task anymore and the function name becomes confusing 🤔

Member Author replied:

Yeah, good call! This function isn't exposed at least, but I'll think about it.

Member Author replied:

Maybe `load_task_extras` would be a better name for this.

Member Author replied:

Trying this out.

# If there is no `task.json` or it's for the wrong class delegate to the
# super class loader.
if not load_task:
return super().load_task(cls, load_weights, load_task, **kwargs)
if not check_file_exists(self.preset, TASK_CONFIG_FILE):
return super().load_task(cls, load_weights, **kwargs)
raise ValueError(
"Saved preset has no `task.json`, cannot load the task config "
"from a file. Call `from_preset()` with `load_task=False` to "
"load the task from a backbone with library defaults."
)
task_config = load_json(self.preset, TASK_CONFIG_FILE)
if not issubclass(check_config_class(task_config), cls):
return super().load_task(cls, load_weights, **kwargs)
raise ValueError(
f"Saved `task.json`does not match calling cls {cls}. Call "
"`from_preset()` with `load_task=False` to load the task from "
"a backbone with library defaults."
)
# We found a `task.json` with a complete config for our class.
task = load_serialized_object(task_config, **kwargs)
if task.preprocessor is not None:
task.preprocessor.tokenizer.load_preset_assets(self.preset)
if load_weights:
jax_memory_cleanup(task)
jax_memory_cleanup(task.backbone)
if check_file_exists(self.preset, TASK_WEIGHTS_FILE):
task_weights = get_file(self.preset, TASK_WEIGHTS_FILE)
task.load_task_weights(task_weights)
backbone_weights = get_file(self.preset, MODEL_WEIGHTS_FILE)
task.backbone.load_weights(backbone_weights)
return task

def load_preprocessor(self, cls, **kwargs):
# If there is no `preprocessing.json` or it's for the wrong class,
# delegate to the super class loader.
def load_preprocessor(self, cls, load_task, **kwargs):
if not load_task:
return super().load_preprocessor(cls, load_task, **kwargs)
if not check_file_exists(self.preset, PREPROCESSOR_CONFIG_FILE):
return super().load_preprocessor(cls, **kwargs)
raise ValueError(
"Saved preset has no `preprocessor.json`, cannot load the task "
"preprocessing config from a file. Call `from_preset()` with "
"`load_task=False` to load the preprocessor with library "
"defaults."
)
preprocessor_json = load_json(self.preset, PREPROCESSOR_CONFIG_FILE)
if not issubclass(check_config_class(preprocessor_json), cls):
return super().load_preprocessor(cls, **kwargs)
raise ValueError(
f"Saved `preprocessor.json`does not match calling cls {cls}. "
"Call `from_preset()` with `load_task=False` to "
"load the the preprocessor with library defaults."
)
# We found a `preprocessor.json` with a complete config for our class.
preprocessor = load_serialized_object(preprocessor_json, **kwargs)
preprocessor.tokenizer.load_preset_assets(self.preset)
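
Two closing notes on the loader changes above. First, when the flag is set, a preset that lacks the requested config now raises instead of silently falling back to library defaults; a hedged sketch, where `./backbone_only_preset` is a hypothetical directory with a backbone config but no `task.json`:

```python
import keras_nlp

# Per the diff above, asking for the saved task config when none exists
# raises a ValueError that points the user at `load_task=False`.
try:
    keras_nlp.models.BertTextClassifier.from_preset(
        "./backbone_only_preset", load_task=True
    )
except ValueError as err:
    print(err)
```

Second, the PR title uses `load_task_extras`: per the review discussion on `load_task` above, the flag is renamed in a later commit of this PR, which is not part of the single commit shown here.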