Commit 1c5d9b3

Add configurable lora_alpha parameter for LoRA in multiple Keras layers (#21139)
* feat: Add `alpha` parameter to `enable_lora`. Adds an alpha scaling parameter to LoRA layers, defaulting to `rank` for backward compatibility.
* feat: Add `lora_alpha` tests to `Dense`, `Embedding`, and `EinsumDense` layers.
* fix: Fix LoRA test failures by using ops to do numpy conversion.
* fix: Remove `.numpy()` in LoRA tests.
* docs: Apply backticks to keywords per review. Updated docstrings to enclose parameters like `alpha` and `rank` in backticks as requested in PR review.
1 parent 7ed8edb commit 1c5d9b3

File tree: 8 files changed, +227 −23 lines changed

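For orientation, here is a minimal usage sketch of the new argument based on the commit message and the diffs below; the layer sizes and the `rank=4`, `lora_alpha=8` values are arbitrary illustrations, not part of the commit.

import numpy as np
from keras import layers

# `lora_alpha` can be passed at construction time (applied when the layer
# builds) or directly to `enable_lora`. Leaving it unset keeps the previous
# behavior: `lora_alpha` defaults to `lora_rank`, so the scale
# `lora_alpha / lora_rank` is 1.
dense = layers.Dense(units=16, lora_rank=4, lora_alpha=8)
dense.build((None, 32))

other = layers.Dense(units=16)
other.build((None, 32))
other.enable_lora(rank=4, lora_alpha=8)

# In both cases the effective kernel is
#     kernel + (lora_alpha / lora_rank) * lora_kernel_a @ lora_kernel_b
y = dense(np.random.rand(2, 32).astype("float32"))
print(y.shape)  # (2, 16)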

keras/src/layers/convolutional/base_conv.py

Lines changed: 18 additions & 5 deletions
@@ -80,6 +80,11 @@ class BaseConv(Layer):
             computation cost of fine-tuning large dense layers.
             You can also enable LoRA on an existing layer by calling
             `layer.enable_lora(rank)`.
+        lora_alpha: Optional integer. If set, this parameter scales the
+            low-rank adaptation delta (computed as the product of two lower-rank
+            trainable matrices) during the forward pass. The delta is scaled by
+            `lora_alpha / lora_rank`, allowing you to fine-tune the strength of
+            the LoRA adjustment independently of `lora_rank`.
     """
 
     def __init__(
@@ -102,6 +107,7 @@ def __init__(
         kernel_constraint=None,
         bias_constraint=None,
         lora_rank=None,
+        lora_alpha=None,
         **kwargs,
     ):
         super().__init__(activity_regularizer=activity_regularizer, **kwargs)
@@ -124,6 +130,7 @@ def __init__(
         self.kernel_constraint = constraints.get(kernel_constraint)
         self.bias_constraint = constraints.get(bias_constraint)
         self.lora_rank = lora_rank
+        self.lora_alpha = lora_alpha if lora_alpha is not None else lora_rank
         self.lora_enabled = False
         self.input_spec = InputSpec(min_ndim=self.rank + 2)
         self.data_format = self.data_format
@@ -215,7 +222,7 @@ def build(self, input_shape):
             self.bias = None
         self.built = True
         if self.lora_rank:
-            self.enable_lora(self.lora_rank)
+            self.enable_lora(self.lora_rank, lora_alpha=self.lora_alpha)
 
     @property
     def kernel(self):
@@ -224,9 +231,9 @@ def kernel(self):
                 "You must build the layer before accessing `kernel`."
             )
         if self.lora_enabled:
-            return self._kernel + ops.matmul(
-                self.lora_kernel_a, self.lora_kernel_b
-            )
+            return self._kernel + (
+                self.lora_alpha / self.lora_rank
+            ) * ops.matmul(self.lora_kernel_a, self.lora_kernel_b)
         return self._kernel
 
     def convolution_op(self, inputs, kernel):
@@ -268,7 +275,11 @@ def compute_output_shape(self, input_shape):
         )
 
     def enable_lora(
-        self, rank, a_initializer="he_uniform", b_initializer="zeros"
+        self,
+        rank,
+        lora_alpha=None,
+        a_initializer="he_uniform",
+        b_initializer="zeros",
     ):
         if self.kernel_constraint:
             raise ValueError(
@@ -301,6 +312,7 @@ def enable_lora(
         self._tracker.lock()
         self.lora_enabled = True
         self.lora_rank = rank
+        self.lora_alpha = lora_alpha if lora_alpha is not None else rank
 
     def save_own_variables(self, store):
         # Do nothing if the layer isn't yet built
@@ -363,6 +375,7 @@ def get_config(self):
         )
         if self.lora_rank:
             config["lora_rank"] = self.lora_rank
+            config["lora_alpha"] = self.lora_alpha
         return config
 
     def _check_load_own_variables(self, store):
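As a side note, here is a standalone NumPy sketch of the scaled delta computed in the `kernel` property above. The shapes and the `rank=2`, `alpha=3.0` values are made up for illustration; as in the conv test below, it assumes `lora_kernel_a` is kernel-shaped except for a trailing `rank` axis and `lora_kernel_b` is `(rank, out_channels)`.

import numpy as np

rank, alpha = 2, 3.0
kernel = np.zeros((2, 2, 3, 5), dtype="float32")         # (kh, kw, in, out)
lora_a = np.full((2, 2, 3, rank), 0.1, dtype="float32")  # (kh, kw, in, rank)
lora_b = np.full((rank, 5), 0.2, dtype="float32")        # (rank, out)

# `np.matmul` contracts the last axis of `lora_a` with the first axis of
# `lora_b` and broadcasts over the leading (kh, kw) axes, giving a
# kernel-shaped delta, which is then scaled by `alpha / rank`.
delta = np.matmul(lora_a, lora_b)                         # (kh, kw, in, out)
effective_kernel = kernel + (alpha / rank) * delta
print(effective_kernel.shape)  # (2, 2, 3, 5)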

keras/src/layers/convolutional/conv_test.py

Lines changed: 46 additions & 0 deletions
@@ -9,6 +9,7 @@
 from keras.src import constraints
 from keras.src import layers
 from keras.src import models
+from keras.src import ops
 from keras.src import saving
 from keras.src import testing
 
@@ -735,6 +736,51 @@ def call(self, x):
             model.conv2d.lora_kernel_a.path, "mymodel/conv2d/lora_kernel_a"
         )
 
+    @pytest.mark.requires_trainable_backend
+    def test_enable_lora_with_alpha(self):
+        # Create a `Conv2D` layer with a small kernel for simplicity.
+        layer = layers.Conv2D(filters=3, kernel_size=(2, 2), padding="valid")
+        # Use a fixed input shape: batch size 1, height=4, width=4, channels=3.
+        input_shape = (1, 4, 4, 3)
+        layer.build(input_shape)
+
+        # Set the base kernel to known, deterministic values.
+        base_kernel = np.linspace(
+            0, 1, num=np.prod(layer.kernel.shape), dtype=np.float32
+        )
+        base_kernel = base_kernel.reshape(layer.kernel.shape)
+        layer.kernel.assign(base_kernel)
+
+        # Enable LoRA with `rank`=2 and a custom `lora_alpha` value (e.g. 3.0).
+        layer.enable_lora(rank=2, lora_alpha=3.0)
+        self.assertEqual(layer.lora_rank, 2)
+        self.assertEqual(layer.lora_alpha, 3.0)
+
+        # For `Conv2D`, assume the LoRA weights have shapes:
+        # `lora_kernel_a`: (kernel_height, kernel_width, in_channels, rank)
+        # `lora_kernel_b`: (rank, out_channels)
+        lora_a_shape = layer.lora_kernel_a.shape
+        lora_b_shape = layer.lora_kernel_b.shape
+
+        # Assign known constant values to LoRA weights.
+        lora_a = np.full(lora_a_shape, 0.1, dtype=np.float32)
+        lora_b = np.full(lora_b_shape, 0.2, dtype=np.float32)
+        layer.lora_kernel_a.assign(lora_a)
+        layer.lora_kernel_b.assign(lora_b)
+
+        # Compute the expected delta.
+        # Flatten `lora_kernel_a` to shape (-1, `rank`),
+        # multiply with `lora_kernel_b`,
+        # then reshape to the kernel's shape.
+        scaling = 3.0 / 2  # `lora_alpha / lora_rank`
+        delta = np.matmul(lora_a.reshape(-1, 2), lora_b)
+        delta = delta.reshape(base_kernel.shape)
+        expected_effective_kernel = base_kernel + scaling * delta
+
+        # Compare the effective kernel computed via the property.
+        actual_effective_kernel = ops.convert_to_numpy(layer.kernel)
+        self.assertAllClose(actual_effective_kernel, expected_effective_kernel)
+
     @pytest.mark.requires_trainable_backend
     def test_lora_rank_argument(self):
         self.run_layer_test(

keras/src/layers/core/dense.py

Lines changed: 20 additions & 6 deletions
@@ -57,6 +57,11 @@ class Dense(Layer):
             computation cost of fine-tuning large dense layers.
             You can also enable LoRA on an existing
             `Dense` layer by calling `layer.enable_lora(rank)`.
+        lora_alpha: Optional integer. If set, this parameter scales the
+            low-rank adaptation delta (computed as the product of two lower-rank
+            trainable matrices) during the forward pass. The delta is scaled by
+            `lora_alpha / lora_rank`, allowing you to fine-tune the strength of
+            the LoRA adjustment independently of `lora_rank`.
 
     Input shape:
         N-D tensor with shape: `(batch_size, ..., input_dim)`.
@@ -82,6 +87,7 @@ def __init__(
         kernel_constraint=None,
         bias_constraint=None,
         lora_rank=None,
+        lora_alpha=None,
         **kwargs,
     ):
         super().__init__(activity_regularizer=activity_regularizer, **kwargs)
@@ -95,6 +101,7 @@ def __init__(
         self.kernel_constraint = constraints.get(kernel_constraint)
         self.bias_constraint = constraints.get(bias_constraint)
         self.lora_rank = lora_rank
+        self.lora_alpha = lora_alpha if lora_alpha is not None else lora_rank
         self.lora_enabled = False
         self.input_spec = InputSpec(min_ndim=2)
         self.supports_masking = True
@@ -135,9 +142,9 @@ def kernel(self):
                 "You must build the layer before accessing `kernel`."
             )
         if self.lora_enabled:
-            return self._kernel + ops.matmul(
-                self.lora_kernel_a, self.lora_kernel_b
-            )
+            return self._kernel + (
+                self.lora_alpha / self.lora_rank
+            ) * ops.matmul(self.lora_kernel_a, self.lora_kernel_b)
         return self._kernel
 
     def call(self, inputs, training=None):
@@ -154,7 +161,11 @@ def compute_output_shape(self, input_shape):
         return tuple(output_shape)
 
     def enable_lora(
-        self, rank, a_initializer="he_uniform", b_initializer="zeros"
+        self,
+        rank,
+        lora_alpha=None,
+        a_initializer="he_uniform",
+        b_initializer="zeros",
     ):
         if self.kernel_constraint:
             raise ValueError(
@@ -187,6 +198,7 @@ def enable_lora(
         self._tracker.lock()
         self.lora_enabled = True
         self.lora_rank = rank
+        self.lora_alpha = lora_alpha if lora_alpha is not None else rank
 
     def save_own_variables(self, store):
         # Do nothing if the layer isn't yet built
@@ -261,6 +273,7 @@ def get_config(self):
         }
         if self.lora_rank:
             config["lora_rank"] = self.lora_rank
+            config["lora_alpha"] = self.lora_alpha
         return {**base_config, **config}
 
     def _check_load_own_variables(self, store):
@@ -410,7 +423,7 @@ def grad_fn(*args, upstream=None):
         if self.lora_enabled:
             lora_x = ops.matmul(inputs, self.lora_kernel_a)
             lora_x = ops.matmul(lora_x, self.lora_kernel_b)
-            x = ops.add(x, lora_x)
+            x = ops.add(x, (self.lora_alpha / self.lora_rank) * lora_x)
         if self.bias is not None:
             x = ops.add(x, self.bias)
         if self.activation is not None:
@@ -544,7 +557,8 @@ def _get_kernel_with_merged_lora(self):
             kernel_value = ops.divide(kernel_value, kernel_scale)
             kernel_value = ops.add(
                 kernel_value,
-                ops.matmul(self.lora_kernel_a, self.lora_kernel_b),
+                (self.lora_alpha / self.lora_rank)
+                * ops.matmul(self.lora_kernel_a, self.lora_kernel_b),
             )
             kernel_value, kernel_scale = quantizers.abs_max_quantize(
                 kernel_value, axis=0, to_numpy=True
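A short NumPy sketch (arbitrary shapes and random values, not from the commit) of the two code paths changed above: scaling the low-rank activation term in `call` is equivalent to scaling the merged kernel delta used by the `kernel` property and `_get_kernel_with_merged_lora`.

import numpy as np

rank, alpha = 2, 3.0
x = np.random.rand(4, 8).astype("float32")        # (batch, input_dim)
kernel = np.random.rand(8, 16).astype("float32")  # frozen base kernel
lora_a = np.random.rand(8, rank).astype("float32")
lora_b = np.random.rand(rank, 16).astype("float32")

# Forward-pass formulation, as in `call`: add the scaled low-rank term.
y_path = x @ kernel + (alpha / rank) * (x @ lora_a @ lora_b)
# Merged-kernel formulation, as in the `kernel` property: fold the scaled
# delta into the kernel before the matmul.
y_merged = x @ (kernel + (alpha / rank) * (lora_a @ lora_b))
assert np.allclose(y_path, y_merged, atol=1e-5)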

keras/src/layers/core/dense_test.py

Lines changed: 26 additions & 0 deletions
@@ -271,6 +271,32 @@ def test_enable_lora(self):
         model.load_weights(temp_filepath)
         self.assertAllClose(model.predict(x), new_model.predict(x))
 
+    @pytest.mark.requires_trainable_backend
+    def test_enable_lora_with_alpha(self):
+        # Create a `Dense` layer and build it.
+        layer = layers.Dense(units=8)
+        layer.build((None, 4))
+
+        # Enable LoRA with `rank`=2 and `lora_alpha`=3.0.
+        layer.enable_lora(2, lora_alpha=3.0)
+        self.assertEqual(layer.lora_rank, 2)
+        self.assertEqual(layer.lora_alpha, 3.0)
+
+        # Manually compute the expected effective kernel:
+        # `effective_kernel_expected` = `base_kernel` +
+        #     `lora_alpha / lora_rank` * `lora_kernel_a @ lora_kernel_b`
+        base_kernel = ops.convert_to_numpy(layer._kernel)
+        lora_update = np.matmul(
+            ops.convert_to_numpy(layer.lora_kernel_a),
+            ops.convert_to_numpy(layer.lora_kernel_b),
+        )
+        effective_kernel_expected = base_kernel + (3.0 / 2) * lora_update
+
+        # Verify that the effective kernel matches expectation.
+        self.assertAllClose(
+            ops.convert_to_numpy(layer.kernel), effective_kernel_expected
+        )
+
     @pytest.mark.requires_trainable_backend
     def test_lora_weight_name(self):
         class MyModel(models.Model):

keras/src/layers/core/einsum_dense.py

Lines changed: 21 additions & 7 deletions
@@ -58,6 +58,11 @@ class EinsumDense(Layer):
             computation cost of fine-tuning large dense layers.
             You can also enable LoRA on an existing
             `EinsumDense` layer by calling `layer.enable_lora(rank)`.
+        lora_alpha: Optional integer. If set, this parameter scales the
+            low-rank adaptation delta (computed as the product of two lower-rank
+            trainable matrices) during the forward pass. The delta is scaled by
+            `lora_alpha / lora_rank`, allowing you to fine-tune the strength of
+            the LoRA adjustment independently of `lora_rank`.
         **kwargs: Base layer keyword arguments, such as `name` and `dtype`.
 
     Examples:
@@ -125,6 +130,7 @@ def __init__(
         kernel_constraint=None,
         bias_constraint=None,
         lora_rank=None,
+        lora_alpha=None,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -142,6 +148,7 @@ def __init__(
         self.kernel_constraint = constraints.get(kernel_constraint)
         self.bias_constraint = constraints.get(bias_constraint)
         self.lora_rank = lora_rank
+        self.lora_alpha = lora_alpha if lora_alpha is not None else lora_rank
         self.lora_enabled = False
 
     def build(self, input_shape):
@@ -184,7 +191,7 @@ def build(self, input_shape):
             self.bias = None
         self.built = True
         if self.lora_rank:
-            self.enable_lora(self.lora_rank)
+            self.enable_lora(self.lora_rank, lora_alpha=self.lora_alpha)
 
     @property
     def kernel(self):
@@ -193,9 +200,9 @@ def kernel(self):
                 "You must build the layer before accessing `kernel`."
             )
         if self.lora_enabled:
-            return self._kernel + ops.matmul(
-                self.lora_kernel_a, self.lora_kernel_b
-            )
+            return self._kernel + (
+                self.lora_alpha / self.lora_rank
+            ) * ops.matmul(self.lora_kernel_a, self.lora_kernel_b)
         return self._kernel
 
     def compute_output_shape(self, _):
@@ -210,7 +217,11 @@ def call(self, inputs, training=None):
         return x
 
     def enable_lora(
-        self, rank, a_initializer="he_uniform", b_initializer="zeros"
+        self,
+        rank,
+        lora_alpha=None,
+        a_initializer="he_uniform",
+        b_initializer="zeros",
     ):
         if self.kernel_constraint:
             raise ValueError(
@@ -243,6 +254,7 @@ def enable_lora(
         self._tracker.lock()
         self.lora_enabled = True
         self.lora_rank = rank
+        self.lora_alpha = lora_alpha if lora_alpha is not None else rank
 
     def save_own_variables(self, store):
         # Do nothing if the layer isn't yet built
@@ -321,6 +333,7 @@ def get_config(self):
         }
         if self.lora_rank:
             config["lora_rank"] = self.lora_rank
+            config["lora_alpha"] = self.lora_alpha
         return {**base_config, **config}
 
     def _check_load_own_variables(self, store):
@@ -525,7 +538,7 @@ def grad_fn(*args, upstream=None):
         if self.lora_enabled:
             lora_x = ops.einsum(self.equation, inputs, self.lora_kernel_a)
             lora_x = ops.matmul(lora_x, self.lora_kernel_b)
-            x = ops.add(x, lora_x)
+            x = ops.add(x, (self.lora_alpha / self.lora_rank) * lora_x)
         if self.bias is not None:
             x += self.bias
         if self.activation is not None:
@@ -700,7 +713,8 @@ def _argsort(seq):
             kernel_value = ops.divide(kernel_value, kernel_scale)
             kernel_value = ops.add(
                 kernel_value,
-                ops.matmul(self.lora_kernel_a, self.lora_kernel_b),
+                (self.lora_alpha / self.lora_rank)
+                * ops.matmul(self.lora_kernel_a, self.lora_kernel_b),
             )
             kernel_value, kernel_scale = quantizers.abs_max_quantize(
                 kernel_value, axis=self._kernel_reduced_axes, to_numpy=True
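Finally, a simplified NumPy sketch of the `EinsumDense` LoRA path changed above, assuming the plain matrix-multiply equation "ab,bc->ac"; the shapes, equation, and scale values are illustrative only.

import numpy as np

rank, alpha = 2, 3.0
equation = "ab,bc->ac"
x = np.random.rand(4, 8).astype("float32")
kernel = np.random.rand(8, 16).astype("float32")
lora_a = np.random.rand(8, rank).astype("float32")
lora_b = np.random.rand(rank, 16).astype("float32")

# Base einsum output, then the LoRA path: contract the inputs with
# `lora_kernel_a` using the layer's equation, apply `lora_kernel_b`,
# and scale the result by `alpha / rank`.
base = np.einsum(equation, x, kernel)      # (4, 16)
lora_x = np.einsum(equation, x, lora_a)    # (4, rank)
out = base + (alpha / rank) * (lora_x @ lora_b)
print(out.shape)  # (4, 16)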
