@@ -47,6 +47,34 @@ def test_lora_fine_tuning(self):
         new_out = new_backbone(input_data)
         self.assertAllClose(ref_out, new_out)
 
+    def test_lora_fine_tuning_target_names(self):
+        # Set up backbone and preprocessor.
+        backbone = GemmaBackbone(**self._init_kwargs)
+        backbone.enable_lora(4, target_names=["query"])
+        # 2 layers, 2 LoRA weights per layer.
+        self.assertLen(backbone.trainable_weights, 2 * 2)
+        self.assertLen(backbone.non_trainable_weights, 20)
+        input_data = {
+            "token_ids": np.ones((2, 5), dtype="int32"),
+            "padding_mask": np.ones((2, 5), dtype="int32"),
+        }
+        targets = np.random.normal(size=(2, 5, self._init_kwargs["hidden_dim"]))
+
+        # Test fine-tuning.
+        backbone.compile(optimizer="sgd", loss="mse")
+        backbone.fit(input_data, targets, epochs=1)
+
+        # Test saving and reloading.
+        temp_filepath = os.path.join(
+            self.get_temp_dir(), "lora_model.weights.h5"
+        )
+        backbone.save_weights(temp_filepath)
+        new_backbone = GemmaBackbone(**self._init_kwargs)
+        new_backbone.load_weights(temp_filepath)
+        ref_out = backbone(input_data)
+        new_out = new_backbone(input_data)
+        self.assertAllClose(ref_out, new_out)
+
     def test_lora_saving_and_reloading(self):
         backbone = GemmaBackbone(**self._init_kwargs)
         initial_model_filepath = os.path.join(
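
For context, here is a minimal sketch of how the new `target_names` argument added in this diff could be used outside the test harness. Only the `enable_lora(4, target_names=["query"])` call and the expected weight count are taken from the test above; the constructor hyperparameters below are illustrative assumptions, not the suite's `_init_kwargs`.

```python
import numpy as np
from keras_nlp.models import GemmaBackbone

# Tiny illustrative config (assumed values, not the test's kwargs).
backbone = GemmaBackbone(
    vocabulary_size=256,
    num_layers=2,
    num_query_heads=4,
    num_key_value_heads=2,
    hidden_dim=16,
    intermediate_dim=32,
    head_dim=4,
)

# Restrict rank-4 LoRA adapters to the query projections only;
# all other weights (including value projections) stay frozen.
backbone.enable_lora(4, target_names=["query"])

# Per the test's expectation: 2 layers x 2 LoRA kernels (A and B)
# on "query" = 4 trainable weights.
assert len(backbone.trainable_weights) == 2 * 2

# Fine-tune as usual; gradients flow only through the LoRA kernels.
input_data = {
    "token_ids": np.ones((2, 5), dtype="int32"),
    "padding_mask": np.ones((2, 5), dtype="int32"),
}
targets = np.random.normal(size=(2, 5, 16))
backbone.compile(optimizer="sgd", loss="mse")
backbone.fit(input_data, targets, epochs=1)
```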