
Commit 6ac9222

Enable LoRA target names arg (#2166)
* enable lora target names arg

* Update backbone.py
1 parent 9072586 commit 6ac9222

File tree

2 files changed: +31 −3 lines changed

keras_hub/src/models/backbone.py

Lines changed: 3 additions & 3 deletions
@@ -194,15 +194,15 @@ def get_lora_target_names(self):
         """
         return ["query_dense", "value_dense", "query", "value"]
 
-    def enable_lora(self, rank):
+    def enable_lora(self, rank, target_names=None):
         """Enable Lora on the backbone.
 
         Calling this method will freeze all weights on the backbone,
         while enabling Lora on the query & value `EinsumDense` layers
         of the attention layers.
         """
-        target_names = self.get_lora_target_names()
-
+        if target_names is None:
+            target_names = self.get_lora_target_names()
         self.trainable = True
         self._lora_enabled_layers = []
         self._lora_rank = rank
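
For context, a minimal usage sketch of the new argument (the preset name and rank below are illustrative, not taken from this commit):

    import keras_hub

    # Load a Gemma backbone (the preset name is illustrative).
    backbone = keras_hub.models.GemmaBackbone.from_preset("gemma_2b_en")

    # Previously, enable_lora(rank) always used the names returned by
    # get_lora_target_names(). The new optional argument lets callers
    # restrict LoRA to specific layers, e.g. only the query projections.
    backbone.enable_lora(rank=4, target_names=["query"])

Passing target_names=None preserves the old behavior, so existing callers are unaffected.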

keras_hub/src/models/gemma/gemma_lora_test.py

Lines changed: 28 additions & 0 deletions
@@ -47,6 +47,34 @@ def test_lora_fine_tuning(self):
         new_out = new_backbone(input_data)
         self.assertAllClose(ref_out, new_out)
 
+    def test_lora_fine_tuning_target_names(self):
+        # Set up backbone.
+        backbone = GemmaBackbone(**self._init_kwargs)
+        backbone.enable_lora(4, target_names=["query"])
+        # 2 targeted layers, 2 LoRA weights per layer.
+        self.assertLen(backbone.trainable_weights, 2 * 2)
+        self.assertLen(backbone.non_trainable_weights, 20)
+        input_data = {
+            "token_ids": np.ones((2, 5), dtype="int32"),
+            "padding_mask": np.ones((2, 5), dtype="int32"),
+        }
+        targets = np.random.normal(size=(2, 5, self._init_kwargs["hidden_dim"]))
+
+        # Test fine-tuning.
+        backbone.compile(optimizer="sgd", loss="mse")
+        backbone.fit(input_data, targets, epochs=1)
+
+        # Test saving and reloading.
+        temp_filepath = os.path.join(
+            self.get_temp_dir(), "lora_model.weights.h5"
+        )
+        backbone.save_weights(temp_filepath)
+        new_backbone = GemmaBackbone(**self._init_kwargs)
+        new_backbone.load_weights(temp_filepath)
+        ref_out = backbone(input_data)
+        new_out = new_backbone(input_data)
+        self.assertAllClose(ref_out, new_out)
+
     def test_lora_saving_and_reloading(self):
         backbone = GemmaBackbone(**self._init_kwargs)
         initial_model_filepath = os.path.join(
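
A quick sanity check on the `2 * 2` assertion in the new test (a sketch; the transformer layer count is inferred from the assertion itself, since the test's _init_kwargs are not shown in this diff). In Keras 3, each LoRA-enabled `EinsumDense` layer adds two trainable weights, lora_kernel_a and lora_kernel_b:

    # Sanity check of the trainable-weight count asserted in the test.
    # Assumption: the test backbone has 2 transformer layers (inferred
    # from the `2 * 2` assertion; _init_kwargs is not shown in this diff).
    num_layers = 2              # assumed transformer layer count
    num_targets = 1             # only "query" is passed via target_names
    weights_per_lora_layer = 2  # lora_kernel_a and lora_kernel_b

    trainable = num_layers * num_targets * weights_per_lora_layer
    assert trainable == 2 * 2   # matches assertLen(trainable_weights, 2 * 2)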
