Skip to content

Commit d8b2656

Browse files
james77777778divyashreepathihalli
authored andcommitted
Support kwargs to Backbone.from_preset and fix the dtype forwarding in Task.from_preset (#1742)
1 parent 7a17731 commit d8b2656

File tree

6 files changed

+79
-13
lines changed

6 files changed

+79
-13
lines changed

keras_nlp/src/models/backbone.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,7 @@ class like `keras_nlp.models.Backbone.from_preset()`, or from
213213
f"`from_preset` directly on `{preset_cls.__name__}` instead."
214214
)
215215

216-
backbone = load_serialized_object(preset, CONFIG_FILE)
216+
backbone = load_serialized_object(preset, CONFIG_FILE, **kwargs)
217217
if load_weights:
218218
jax_memory_cleanup(backbone)
219219
backbone.load_weights(get_file(preset, MODEL_WEIGHTS_FILE))

keras_nlp/src/models/backbone_test.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from keras_nlp.src.utils.preset_utils import load_config
2828

2929

30-
class TestTask(TestCase):
30+
class TestBackbone(TestCase):
3131
def test_preset_accessors(self):
3232
bert_presets = set(BertBackbone.presets.keys())
3333
gpt2_presets = set(GPT2Backbone.presets.keys())
@@ -46,6 +46,22 @@ def test_from_preset(self):
4646
GPT2Backbone,
4747
)
4848

49+
@pytest.mark.large
50+
def test_from_preset_with_kwargs(self):
51+
# Test `dtype`
52+
backbone = Backbone.from_preset(
53+
"bert_tiny_en_uncased", load_weights=False, dtype="bfloat16"
54+
)
55+
self.assertIsInstance(backbone, BertBackbone)
56+
self.assertEqual(backbone.dtype_policy.name, "bfloat16")
57+
58+
# Test kwargs forwarding
59+
backbone = Backbone.from_preset(
60+
"bert_tiny_en_uncased", load_weights=False, dropout=0.5
61+
)
62+
self.assertIsInstance(backbone, BertBackbone)
63+
self.assertAllClose(backbone.dropout, 0.5)
64+
4965
@pytest.mark.large
5066
def test_from_preset_errors(self):
5167
with self.assertRaises(ValueError):

keras_nlp/src/models/task.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -258,13 +258,11 @@ def from_preset(
258258
)
259259
cls = subclasses[0]
260260
# Forward dtype to the backbone.
261-
config_overrides = {}
261+
backbone_kwargs = {}
262262
if "dtype" in kwargs:
263-
config_overrides["dtype"] = kwargs.pop("dtype")
263+
backbone_kwargs = {"dtype": kwargs.pop("dtype")}
264264
backbone = backbone_preset_cls.from_preset(
265-
preset,
266-
load_weights=load_weights,
267-
config_overrides=config_overrides,
265+
preset, load_weights=load_weights, **backbone_kwargs
268266
)
269267
if "preprocessor" in kwargs:
270268
preprocessor = kwargs.pop("preprocessor")

keras_nlp/src/models/task_test.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,16 @@ def test_from_preset(self):
7171
# TODO: Add a classifier task loading test when there is a classifier
7272
# with new design available on Kaggle.
7373

74+
@pytest.mark.large
75+
def test_from_preset_with_kwargs(self):
76+
# Test `dtype`
77+
model = CausalLM.from_preset(
78+
"gpt2_base_en", load_weights=False, dtype="bfloat16"
79+
)
80+
self.assertIsInstance(model, GPT2CausalLM)
81+
self.assertEqual(model.dtype_policy.name, "bfloat16")
82+
self.assertEqual(model.backbone.dtype_policy.name, "bfloat16")
83+
7484
@pytest.mark.large
7585
def test_from_preset_errors(self):
7686
with self.assertRaises(ValueError):

keras_nlp/src/utils/preset_utils.py

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -561,13 +561,16 @@ def check_format(preset):
561561
return "keras"
562562

563563

564-
def load_serialized_object(
565-
preset,
566-
config_file=CONFIG_FILE,
567-
config_overrides={},
568-
):
564+
def load_serialized_object(preset, config_file=CONFIG_FILE, **kwargs):
565+
kwargs = kwargs or {}
569566
config = load_config(preset, config_file)
570-
config["config"] = {**config["config"], **config_overrides}
567+
568+
# `dtype` in config might be a serialized `DTypePolicy` or `DTypePolicyMap`.
569+
# Ensure that `dtype` is properly configured.
570+
dtype = kwargs.pop("dtype", None)
571+
config = set_dtype_in_config(config, dtype)
572+
573+
config["config"] = {**config["config"], **kwargs}
571574
return keras.saving.deserialize_keras_object(config)
572575

573576

@@ -590,3 +593,25 @@ def jax_memory_cleanup(layer):
590593
for weight in layer.weights:
591594
if getattr(weight, "_value", None) is not None:
592595
weight._value.delete()
596+
597+
598+
def set_dtype_in_config(config, dtype=None):
599+
if dtype is None:
600+
return config
601+
602+
config = config.copy()
603+
if "dtype" not in config["config"]:
604+
# Forward `dtype` to the config.
605+
config["config"]["dtype"] = dtype
606+
elif (
607+
"dtype" in config["config"]
608+
and isinstance(config["config"]["dtype"], dict)
609+
and "DTypePolicyMap" in config["config"]["dtype"]["class_name"]
610+
):
611+
# If it is `DTypePolicyMap` in `config`, forward `dtype` as its default
612+
# policy.
613+
policy_map_config = config["config"]["dtype"]["config"]
614+
policy_map_config["default_policy"] = dtype
615+
for k in policy_map_config["policy_map"].keys():
616+
policy_map_config["policy_map"][k]["config"]["source_name"] = dtype
617+
return config

keras_nlp/src/utils/preset_utils_test.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,12 @@
2323
from keras_nlp.src.models import BertBackbone
2424
from keras_nlp.src.models import BertTokenizer
2525
from keras_nlp.src.tests.test_case import TestCase
26+
from keras_nlp.src.utils.keras_utils import has_quantization_support
2627
from keras_nlp.src.utils.preset_utils import CONFIG_FILE
2728
from keras_nlp.src.utils.preset_utils import METADATA_FILE
2829
from keras_nlp.src.utils.preset_utils import TOKENIZER_CONFIG_FILE
2930
from keras_nlp.src.utils.preset_utils import check_format
31+
from keras_nlp.src.utils.preset_utils import load_serialized_object
3032

3133

3234
class PresetUtilsTest(TestCase):
@@ -113,3 +115,18 @@ def test_incorrect_metadata(self):
113115

114116
with self.assertRaisesRegex(ValueError, "doesn't have `keras_version`"):
115117
check_format(preset_dir)
118+
119+
@parameterized.named_parameters(
120+
("gemma2_2b_en", "gemma2_2b_en", "bfloat16", False),
121+
("llama2_7b_en_int8", "llama2_7b_en_int8", "bfloat16", True),
122+
)
123+
@pytest.mark.extra_large
124+
def test_load_serialized_object(self, preset, dtype, is_quantized):
125+
if is_quantized and not has_quantization_support():
126+
self.skipTest("This version of Keras doesn't support quantization.")
127+
128+
model = load_serialized_object(preset, dtype=dtype)
129+
if is_quantized:
130+
self.assertEqual(model.dtype_policy.name, "map_bfloat16")
131+
else:
132+
self.assertEqual(model.dtype_policy.name, "bfloat16")

0 commit comments

Comments
 (0)