Commit e5c6ddb
Some routine cleanup while writing new tools for checkpoint admin:

- Remove a broken test in preset_utils that we never actually run.
- Move load_serialized_object to our preset loading class (for consistency).
- Move all admin-related tooling to a dedicated folder in tools/.
- Remove some scripts that are no longer used.
Parent: 08a4681

File tree

6 files changed (+15, -137 lines)
keras_hub/src/utils/preset_utils.py

Lines changed: 15 additions & 16 deletions
@@ -454,16 +454,6 @@ def load_json(preset, config_file=CONFIG_FILE):
     return config
 
 
-def load_serialized_object(config, **kwargs):
-    # `dtype` in config might be a serialized `DTypePolicy` or `DTypePolicyMap`.
-    # Ensure that `dtype` is properly configured.
-    dtype = kwargs.pop("dtype", None)
-    config = set_dtype_in_config(config, dtype)
-
-    config["config"] = {**config["config"], **kwargs}
-    return keras.saving.deserialize_keras_object(config)
-
-
 def check_config_class(config):
     """Validate a preset is being loaded on the correct class."""
     registered_name = config["registered_name"]
@@ -631,26 +621,26 @@ def check_backbone_class(self):
         return check_config_class(self.config)
 
     def load_backbone(self, cls, load_weights, **kwargs):
-        backbone = load_serialized_object(self.config, **kwargs)
+        backbone = self._load_serialized_object(self.config, **kwargs)
         if load_weights:
             jax_memory_cleanup(backbone)
             backbone.load_weights(get_file(self.preset, MODEL_WEIGHTS_FILE))
         return backbone
 
     def load_tokenizer(self, cls, config_file=TOKENIZER_CONFIG_FILE, **kwargs):
         tokenizer_config = load_json(self.preset, config_file)
-        tokenizer = load_serialized_object(tokenizer_config, **kwargs)
+        tokenizer = self._load_serialized_object(tokenizer_config, **kwargs)
         if hasattr(tokenizer, "load_preset_assets"):
             tokenizer.load_preset_assets(self.preset)
         return tokenizer
 
     def load_audio_converter(self, cls, **kwargs):
         converter_config = load_json(self.preset, AUDIO_CONVERTER_CONFIG_FILE)
-        return load_serialized_object(converter_config, **kwargs)
+        return self._load_serialized_object(converter_config, **kwargs)
 
     def load_image_converter(self, cls, **kwargs):
         converter_config = load_json(self.preset, IMAGE_CONVERTER_CONFIG_FILE)
-        return load_serialized_object(converter_config, **kwargs)
+        return self._load_serialized_object(converter_config, **kwargs)
 
     def load_task(self, cls, load_weights, load_task_weights, **kwargs):
         # If there is no `task.json` or it's for the wrong class delegate to the
@@ -671,7 +661,7 @@ def load_task(self, cls, load_weights, load_task_weights, **kwargs):
             backbone_config = task_config["config"]["backbone"]["config"]
             backbone_config = {**backbone_config, **backbone_kwargs}
             task_config["config"]["backbone"]["config"] = backbone_config
-        task = load_serialized_object(task_config, **kwargs)
+        task = self._load_serialized_object(task_config, **kwargs)
         if task.preprocessor and hasattr(
             task.preprocessor, "load_preset_assets"
         ):
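
These loader methods back the public `from_preset` constructors, so moving the helper onto the loader class changes nothing for callers. As a usage sketch of the path that reaches this code (the preset name here is an illustrative assumption, not part of this commit):

import keras_hub

# `from_preset` resolves the preset, then the loaders above deserialize
# the config and attach weights and assets.
tokenizer = keras_hub.models.BertTokenizer.from_preset("bert_tiny_en_uncased")
token_ids = tokenizer(["The quick brown fox jumped."])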
@@ -699,11 +689,20 @@ def load_preprocessor(
         if not issubclass(check_config_class(preprocessor_json), cls):
             return super().load_preprocessor(cls, **kwargs)
         # We found a `preprocessing.json` with a complete config for our class.
-        preprocessor = load_serialized_object(preprocessor_json, **kwargs)
+        preprocessor = self._load_serialized_object(preprocessor_json, **kwargs)
         if hasattr(preprocessor, "load_preset_assets"):
             preprocessor.load_preset_assets(self.preset)
         return preprocessor
 
+    def _load_serialized_object(self, config, **kwargs):
+        # `dtype` in config might be a serialized `DTypePolicy` or
+        # `DTypePolicyMap`. Ensure that `dtype` is properly configured.
+        dtype = kwargs.pop("dtype", None)
+        config = set_dtype_in_config(config, dtype)
+
+        config["config"] = {**config["config"], **kwargs}
+        return keras.saving.deserialize_keras_object(config)
+
 
 class KerasPresetSaver:
     def __init__(self, preset_dir):
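
For reference, the helper's body is unchanged by the move: pop an optional `dtype`, route it through `set_dtype_in_config`, merge the remaining kwargs into the nested "config" dict, and deserialize. A minimal, self-contained sketch of the merge-and-deserialize step on a toy `Dense` config (illustrative only, not the preset format itself):

import keras

# A serialized Keras object in the shape the helper consumes: outer
# metadata plus a nested "config" dict of constructor kwargs.
config = {
    "module": "keras.layers",
    "class_name": "Dense",
    "registered_name": None,
    "config": {"units": 4, "activation": "relu"},
}

# Caller kwargs (including an overriding `dtype`) are merged into the
# inner config before deserialization, mirroring the helper's last lines.
kwargs = {"units": 8, "dtype": "bfloat16"}
config["config"] = {**config["config"], **kwargs}

layer = keras.saving.deserialize_keras_object(config)
print(layer.units, layer.dtype_policy.name)  # 8 bfloat16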

keras_hub/src/utils/preset_utils_test.py

Lines changed: 0 additions & 17 deletions
@@ -11,9 +11,7 @@
 from keras_hub.src.models.bert.bert_backbone import BertBackbone
 from keras_hub.src.models.bert.bert_tokenizer import BertTokenizer
 from keras_hub.src.tests.test_case import TestCase
-from keras_hub.src.utils.keras_utils import has_quantization_support
 from keras_hub.src.utils.preset_utils import CONFIG_FILE
-from keras_hub.src.utils.preset_utils import load_serialized_object
 from keras_hub.src.utils.preset_utils import upload_preset
 
 
@@ -88,18 +86,3 @@ def test_upload_with_invalid_json(self):
         # Verify error handling.
         with self.assertRaisesRegex(ValueError, "is an invalid json"):
             upload_preset("kaggle://test/test/test", local_preset_dir)
-
-    @parameterized.named_parameters(
-        ("gemma2_2b_en", "gemma2_2b_en", "bfloat16", False),
-        ("llama2_7b_en_int8", "llama2_7b_en_int8", "bfloat16", True),
-    )
-    @pytest.mark.extra_large
-    def test_load_serialized_object(self, preset, dtype, is_quantized):
-        if is_quantized and not has_quantization_support():
-            self.skipTest("This version of Keras doesn't support quantization.")
-
-        model = load_serialized_object(preset, dtype=dtype)
-        if is_quantized:
-            self.assertEqual(model.dtype_policy.name, "map_bfloat16")
-        else:
-            self.assertEqual(model.dtype_policy.name, "bfloat16")
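
Worth noting why the deleted test counted as broken: it appears to pass a preset handle string ("gemma2_2b_en") to `load_serialized_object`, which expects a serialized config dict, so it could not have passed if the extra-large suite ever ran it. Equivalent coverage would now go through the public constructors; a hedged sketch of such a test (preset choice is an assumption, and this is not part of the commit):

import pytest
import keras_hub

@pytest.mark.extra_large
def test_dtype_override_via_public_api():
    # `from_preset` routes through the now-private helper, so a `dtype`
    # override here exercises the same code path the old test targeted.
    backbone = keras_hub.models.Backbone.from_preset(
        "gemma2_2b_en", dtype="bfloat16"
    )
    assert backbone.dtype_policy.name == "bfloat16"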
3 files renamed without changes.

tools/convert_legacy_presets.py

Lines changed: 0 additions & 104 deletions
This file was deleted.
