@@ -175,6 +175,13 @@ def __init__(self,
         self.embed_positions = None
         self.rotary_inv_freq = None
         self.embed_positions_for_gpt_attention = None
+
+        # Auxiliary params to support models with non-homogeneous attn layers
+        # requiring a different set of rope params, e.g. Gemma3.
+        self.embed_positions_local = None
+        self.rotary_inv_freq_local = None
+        self.embed_positions_for_gpt_attention_local = None
+
         # long rope const parameters
         self.long_rope_embed_positions = None
         self.long_rope_rotary_inv_freq = None
@@ -186,10 +193,16 @@ def fill_attention_const_params_for_rope(
             self,
             embed_positions: Tensor = None,
             rotary_inv_freq: Tensor = None,
-            embed_positions_for_gpt_attention: Tensor = None):
+            embed_positions_for_gpt_attention: Tensor = None,
+            embed_positions_local: Tensor = None,
+            rotary_inv_freq_local: Tensor = None,
+            embed_positions_for_gpt_attention_local: Tensor = None):
         self.embed_positions = embed_positions
         self.rotary_inv_freq = rotary_inv_freq
         self.embed_positions_for_gpt_attention = embed_positions_for_gpt_attention
+        self.embed_positions_local = embed_positions_local
+        self.rotary_inv_freq_local = rotary_inv_freq_local
+        self.embed_positions_for_gpt_attention_local = embed_positions_for_gpt_attention_local
         return self

     def fill_attention_const_params_for_long_rope(
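For orientation only (this note and sketch are not part of the diff): existing callers of the setter stay unchanged and the new *_local fields simply remain None, while a model carrying a second rope set passes the extra arguments. The rope tensor variables below are placeholders assumed to have been created elsewhere.

# Illustrative sketch, not code from this patch; assumes the rope tensors
# referenced below already exist under these placeholder names.
from tensorrt_llm.layers.attention import AttentionParams

params = AttentionParams()

# Previous behaviour: only the global rope set; the *_local fields stay None.
params.fill_attention_const_params_for_rope(
    embed_positions, rotary_inv_freq, embed_positions_for_gpt_attention)

# With a second, local rope set (e.g. for a Gemma3-style model).
params.fill_attention_const_params_for_rope(
    embed_positions,
    rotary_inv_freq,
    embed_positions_for_gpt_attention,
    embed_positions_local=embed_positions_local,
    rotary_inv_freq_local=rotary_inv_freq_local,
    embed_positions_for_gpt_attention_local=embed_positions_for_gpt_attention_local)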
@@ -359,6 +372,7 @@ def __init__(self,
                  dtype=None,
                  position_embedding_type=PositionEmbeddingType.learned_absolute,
                  rotary_embedding_base=10000.0,
+                 rotary_embedding_base_local=1.0,
                  rotary_embedding_scaling=None,
                  rotary_embedding_percentage=1.0,
                  rope_scaling_short_factors=None,
@@ -388,7 +402,8 @@ def __init__(self,
                  cp_size=1,
                  cp_rank=0,
                  max_seqlen_for_logn_scaling=8192,
-                 use_logn_scaling=False):
+                 use_logn_scaling=False,
+                 is_local=False):
         super().__init__()

         self.local_layer_idx = local_layer_idx
@@ -417,6 +432,7 @@ def __init__(self,
         self.cp_group = cp_group
         self.cp_size = cp_size
         self.cp_rank = cp_rank
+        self.is_local = is_local

         self.num_layers = num_layers
         self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
@@ -437,6 +453,7 @@ def __init__(self,
         self.max_distance = max_distance
         self.num_buckets = num_buckets
         self.rotary_embedding_base = rotary_embedding_base
+        self.rotary_embedding_base_local = rotary_embedding_base_local
         self.rotary_embedding_scaling = rotary_embedding_scaling
         self.rotary_embedding_scale_type = RotaryScalingType.none
         self.rotary_embedding_scale = 1.0
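As an illustration of how the two new constructor arguments fit together (not part of the diff): a model that interleaves sliding-window and full attention, such as Gemma3, can pass a per-layer is_local flag together with a separate local base frequency. The layer pattern, variable names, and values below are placeholders, not the actual model code.

# Hypothetical wiring for a model that alternates local (sliding-window) and
# global attention layers; every value here is a stand-in.
layers = []
for layer_idx in range(num_hidden_layers):
    # Assumption: one global layer every 'sliding_window_pattern' layers.
    is_local = (layer_idx + 1) % sliding_window_pattern != 0
    layers.append(
        Attention(local_layer_idx=layer_idx,
                  hidden_size=hidden_size,
                  num_attention_heads=num_attention_heads,
                  rotary_embedding_base=rope_theta,  # base for global layers
                  rotary_embedding_base_local=rope_local_base_freq,  # local layers
                  is_local=is_local))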
@@ -656,26 +673,45 @@ def create_attention_const_params(model_cls, config):
             model_cls.short_mscale = short_mscale
             model_cls.long_mscale = long_mscale
         else:
-            # Rotary const weights.
-            embed_positions = RopeEmbeddingUtils.create_sinusoidal_positions(
-                max_position_embeddings,
-                rotary_embedding_dim,
-            )
-            rotary_inv_freq, embed_positions_for_gpt_attention = RopeEmbeddingUtils.create_sinusoidal_positions_for_attention_plugin(
-                max_position_embeddings, rotary_embedding_dim,
-                rotary_embedding_base, rotary_embedding_scale,
-                rotary_embedding_scale_type, rotary_embedding_scaling)
-            model_cls.register_parameter(
-                'embed_positions',
-                Parameter(embed_positions, dtype='float32', is_buffer=True))
-            model_cls.register_parameter(
-                'rotary_inv_freq',
-                Parameter(rotary_inv_freq, dtype='float32', is_buffer=True))
-            model_cls.register_parameter(
-                'embed_positions_for_gpt_attention',
-                Parameter(embed_positions_for_gpt_attention,
-                          dtype='float32',
-                          is_buffer=True))
+
+            def register_rope_params(rotary_base, names_to_register):
+                # Rotary const weights.
+                embed_positions = RopeEmbeddingUtils.create_sinusoidal_positions(
+                    max_position_embeddings,
+                    rotary_embedding_dim,
+                )
+                rotary_inv_freq, embed_positions_for_gpt_attention = RopeEmbeddingUtils.create_sinusoidal_positions_for_attention_plugin(
+                    max_position_embeddings, rotary_embedding_dim, rotary_base,
+                    rotary_embedding_scale, rotary_embedding_scale_type,
+                    rotary_embedding_scaling)
+                model_cls.register_parameter(
+                    names_to_register[0],
+                    Parameter(embed_positions, dtype='float32', is_buffer=True))
+                model_cls.register_parameter(
+                    names_to_register[1],
+                    Parameter(rotary_inv_freq, dtype='float32', is_buffer=True))
+                model_cls.register_parameter(
+                    names_to_register[2],
+                    Parameter(embed_positions_for_gpt_attention,
+                              dtype='float32',
+                              is_buffer=True))
+
+            register_rope_params(rotary_base=rotary_embedding_base,
+                                 names_to_register=[
+                                     'embed_positions', 'rotary_inv_freq',
+                                     'embed_positions_for_gpt_attention'
+                                 ])
+
+            # For models with non-homogeneous attention layers requiring a second set of rope params, e.g. Gemma3.
+            rotary_embedding_base_local = getattr(config,
+                                                  'rope_local_base_freq', None)
+            if rotary_embedding_base_local is not None:
+                register_rope_params(
+                    rotary_base=rotary_embedding_base_local,
+                    names_to_register=[
+                        'embed_positions_local', 'rotary_inv_freq_local',
+                        'embed_positions_for_gpt_attention_local'
+                    ])

     @staticmethod
     def fill_attention_params(model_cls, attention_params):
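A short note on the gate above (illustration, not part of the diff): the original global registration is unchanged, and the second set of buffers only exists when the pretrained config exposes a rope_local_base_freq field, as Gemma3 configs do. The sketch below restates which parameter names end up registered in each case; 'config' stands for whatever object create_attention_const_params receives.

# Sketch of the registration outcome; 'config' is a placeholder for the
# model's pretrained config object.
registered = [
    'embed_positions', 'rotary_inv_freq', 'embed_positions_for_gpt_attention'
]
if getattr(config, 'rope_local_base_freq', None) is not None:
    # Only models exposing a local rope base (e.g. Gemma3) get these buffers.
    registered += [
        'embed_positions_local', 'rotary_inv_freq_local',
        'embed_positions_for_gpt_attention_local'
    ]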
@@ -695,7 +731,15 @@ def fill_attention_params(model_cls, attention_params):
             return attention_params.fill_attention_const_params_for_rope(
                 model_cls.embed_positions.value,
                 model_cls.rotary_inv_freq.value,
-                model_cls.embed_positions_for_gpt_attention.value)
+                model_cls.embed_positions_for_gpt_attention.value,
+                model_cls.embed_positions_local.value if hasattr(
+                    model_cls, "embed_positions_local") else None,
+                model_cls.rotary_inv_freq_local.value if hasattr(
+                    model_cls, "rotary_inv_freq_local") else None,
+                model_cls.embed_positions_for_gpt_attention_local.value
+                if hasattr(
+                    model_cls,
+                    "embed_positions_for_gpt_attention_local") else None)
         # Fill nothing.
         return attention_params

@@ -1020,6 +1064,11 @@ def compute_cross_kv(encoder_output):
             # Rotary cos/sin cache.
             rotary_cos_sin = getattr(attention_params,
                                      "embed_positions_for_gpt_attention", None)
+            rotary_inv_freq_local = getattr(attention_params,
+                                            "rotary_inv_freq_local", None)
+            rotary_cos_sin_local = getattr(
+                attention_params, "embed_positions_for_gpt_attention_local",
+                None)

             long_rope_rotary_inv_freq = getattr(attention_params,
                                                 "long_rope_rotary_inv_freq",
@@ -1062,7 +1111,8 @@ def compute_cross_kv(encoder_output):
                 hidden_size_per_head=self.attention_head_size,
                 q_scaling=self.q_scaling,
                 rotary_embedding_dim=self.rotary_embedding_dim,
-                rotary_embedding_base=self.rotary_embedding_base,
+                rotary_embedding_base=self.rotary_embedding_base
+                if not self.is_local else self.rotary_embedding_base_local,
                 rotary_embedding_scale_type=self.rotary_embedding_scale_type,
                 rotary_embedding_short_m_scale=attention_params.short_mscale,
                 rotary_embedding_long_m_scale=attention_params.long_mscale,
@@ -1071,8 +1121,10 @@ def compute_cross_kv(encoder_output):
                 rotary_embedding_original_max_positions=self.
                 original_max_position_embeddings,
                 position_embedding_type=self.position_embedding_type,
-                rotary_inv_freq=rotary_inv_freq,
-                rotary_cos_sin=rotary_cos_sin,
+                rotary_inv_freq=rotary_inv_freq
+                if not self.is_local else rotary_inv_freq_local,
+                rotary_cos_sin=rotary_cos_sin
+                if not self.is_local else rotary_cos_sin_local,
                 kv_orig_quant_scale=kv_orig_quant_scale,
                 kv_quant_orig_scale=kv_quant_orig_scale,
                 attention_output_orig_quant_scale=self.
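To spell out the fallback behaviour (illustration only, not part of the diff): the *_local tensors are fetched with getattr and default to None, so they are only consulted for layers constructed with is_local=True; global layers keep using the original base, inv_freq, and cos/sin cache unchanged. In isolation the per-layer selection amounts to:

# Placeholder sketch of the per-layer rope selection applied above.
def select_rope(is_local, global_rope, local_rope):
    # Each argument is a (base, inv_freq, cos_sin) triple; local_rope may
    # contain None entries when no '*_local' buffers were registered.
    return local_rope if is_local else global_rope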