@@ -175,6 +175,11 @@ def __init__(self,
         self.embed_positions = None
         self.rotary_inv_freq = None
         self.embed_positions_for_gpt_attention = None
+
+        self.embed_positions_local = None
+        self.rotary_inv_freq_local = None
+        self.embed_positions_for_gpt_attention_local = None
+
         # long rope const parameters
         self.long_rope_embed_positions = None
         self.long_rope_rotary_inv_freq = None
@@ -186,10 +191,16 @@ def fill_attention_const_params_for_rope(
             self,
             embed_positions: Tensor = None,
             rotary_inv_freq: Tensor = None,
-            embed_positions_for_gpt_attention: Tensor = None):
+            embed_positions_for_gpt_attention: Tensor = None,
+            embed_positions_local: Tensor = None,
+            rotary_inv_freq_local: Tensor = None,
+            embed_positions_for_gpt_attention_local: Tensor = None):
         self.embed_positions = embed_positions
         self.rotary_inv_freq = rotary_inv_freq
         self.embed_positions_for_gpt_attention = embed_positions_for_gpt_attention
+        self.embed_positions_local = embed_positions_local
+        self.rotary_inv_freq_local = rotary_inv_freq_local
+        self.embed_positions_for_gpt_attention_local = embed_positions_for_gpt_attention_local
         return self

     def fill_attention_const_params_for_long_rope(
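
A minimal, self-contained stand-in (not the TensorRT-LLM AttentionParams class; the name RopeConstParams and its slimmed-down argument names are invented for illustration) showing the builder-style pattern this hunk extends: the local RoPE constants are stored next to the existing global ones, and the fill method keeps returning self so call sites can chain fills.

class RopeConstParams:
    def __init__(self):
        # Global RoPE constants.
        self.embed_positions = None
        self.rotary_inv_freq = None
        self.embed_positions_for_gpt_attention = None
        # Local-attention RoPE constants, mirroring the global trio.
        self.embed_positions_local = None
        self.rotary_inv_freq_local = None
        self.embed_positions_for_gpt_attention_local = None

    def fill_rope(self, embed_positions=None, rotary_inv_freq=None, cos_sin=None,
                  embed_positions_local=None, rotary_inv_freq_local=None,
                  cos_sin_local=None):
        self.embed_positions = embed_positions
        self.rotary_inv_freq = rotary_inv_freq
        self.embed_positions_for_gpt_attention = cos_sin
        self.embed_positions_local = embed_positions_local
        self.rotary_inv_freq_local = rotary_inv_freq_local
        self.embed_positions_for_gpt_attention_local = cos_sin_local
        return self  # chainable, like fill_attention_const_params_for_rope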
@@ -359,6 +370,7 @@ def __init__(self,
                  dtype=None,
                  position_embedding_type=PositionEmbeddingType.learned_absolute,
                  rotary_embedding_base=10000.0,
+                 rotary_embedding_base_local=1.0,
                  rotary_embedding_scaling=None,
                  rotary_embedding_percentage=1.0,
                  rope_scaling_short_factors=None,
@@ -388,7 +400,8 @@ def __init__(self,
                  cp_size=1,
                  cp_rank=0,
                  max_seqlen_for_logn_scaling=8192,
-                 use_logn_scaling=False):
+                 use_logn_scaling=False,
+                 is_local=False):
         super().__init__()

         self.local_layer_idx = local_layer_idx
@@ -417,6 +430,7 @@ def __init__(self,
         self.cp_group = cp_group
         self.cp_size = cp_size
         self.cp_rank = cp_rank
+        self.is_local = is_local

         self.num_layers = num_layers
         self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
@@ -437,6 +451,7 @@ def __init__(self,
         self.max_distance = max_distance
         self.num_buckets = num_buckets
         self.rotary_embedding_base = rotary_embedding_base
+        self.rotary_embedding_base_local = rotary_embedding_base_local
         self.rotary_embedding_scaling = rotary_embedding_scaling
         self.rotary_embedding_scale_type = RotaryScalingType.none
         self.rotary_embedding_scale = 1.0
@@ -677,6 +692,29 @@ def create_attention_const_params(model_cls, config):
                           dtype='float32',
                           is_buffer=True))

+        rotary_embedding_base_local = getattr(config, 'rope_local_base_freq', None)
+        if rotary_embedding_base_local is not None:
+            embed_positions_local = RopeEmbeddingUtils.create_sinusoidal_positions(
+                max_position_embeddings,
+                rotary_embedding_dim,
+            )
+            rotary_inv_freq_local, embed_positions_for_gpt_attention_local = RopeEmbeddingUtils.create_sinusoidal_positions_for_attention_plugin(
+                max_position_embeddings, rotary_embedding_dim,
+                rotary_embedding_base_local, rotary_embedding_scale,
+                rotary_embedding_scale_type, rotary_embedding_scaling)
+            model_cls.register_parameter(
+                'embed_positions_local',
+                Parameter(embed_positions_local, dtype='float32', is_buffer=True))
+            model_cls.register_parameter(
+                'rotary_inv_freq_local',
+                Parameter(rotary_inv_freq_local, dtype='float32', is_buffer=True))
+            model_cls.register_parameter(
+                'embed_positions_for_gpt_attention_local',
+                Parameter(embed_positions_for_gpt_attention_local,
+                          dtype='float32',
+                          is_buffer=True))
+
+
     @staticmethod
     def fill_attention_params(model_cls, attention_params):
         if model_cls.position_embedding_type.is_rope():
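
For context, a rough sketch of what the registered constants hold, under the assumption that the attention plugin consumes, per RoPE base, an inverse-frequency vector plus a precomputed cos/sin cache (the exact layout produced by RopeEmbeddingUtils.create_sinusoidal_positions_for_attention_plugin may differ, e.g. interleaved cos/sin; the base values below are examples only). With rope_local_base_freq present in the config, the same construction simply runs a second time with the local base.

import numpy as np

def rope_tables(max_positions: int, rotary_dim: int, base: float):
    # Inverse frequencies: base ** (-2i / d) for i = 0 .. d/2 - 1.
    inv_freq = 1.0 / (base ** (np.arange(0, rotary_dim, 2, dtype=np.float32) / rotary_dim))
    # Angle table: one row per position, one column per frequency.
    angles = np.outer(np.arange(max_positions, dtype=np.float32), inv_freq)
    cos_sin = np.concatenate([np.cos(angles), np.sin(angles)], axis=-1)
    return inv_freq, cos_sin

# Global and local tables differ only in the base (theta) plugged in.
inv_freq, cos_sin = rope_tables(8192, 128, base=1_000_000.0)           # example global rope_theta
inv_freq_local, cos_sin_local = rope_tables(8192, 128, base=10_000.0)  # example rope_local_base_freq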
@@ -695,7 +733,10 @@ def fill_attention_params(model_cls, attention_params):
             return attention_params.fill_attention_const_params_for_rope(
                 model_cls.embed_positions.value,
                 model_cls.rotary_inv_freq.value,
-                model_cls.embed_positions_for_gpt_attention.value)
+                model_cls.embed_positions_for_gpt_attention.value,
+                model_cls.embed_positions_local.value,
+                model_cls.rotary_inv_freq_local.value,
+                model_cls.embed_positions_for_gpt_attention_local.value)
         # Fill nothing.
         return attention_params

@@ -1020,6 +1061,9 @@ def compute_cross_kv(encoder_output):
             # Rotary cos/sin cache.
             rotary_cos_sin = getattr(attention_params,
                                      "embed_positions_for_gpt_attention", None)
+            rotary_inv_freq_local = getattr(attention_params, "rotary_inv_freq_local", None)
+            rotary_cos_sin_local = getattr(attention_params,
+                                           "embed_positions_for_gpt_attention_local", None)

             long_rope_rotary_inv_freq = getattr(attention_params,
                                                 "long_rope_rotary_inv_freq",
@@ -1037,6 +1081,9 @@ def compute_cross_kv(encoder_output):
                 assert (rotary_inv_freq is not None) and (
                     rotary_cos_sin is not None
                 ), "rotary_inv_freq and embed_positions_for_gpt_attention must be provided."
+                assert (rotary_inv_freq_local is not None) and (
+                    rotary_cos_sin_local is not None
+                ), "rotary_inv_freq_local and embed_positions_for_gpt_attention_local must be provided."
                 if self.position_embedding_type == PositionEmbeddingType.long_rope:
                     assert long_rope_rotary_inv_freq is not None
                     assert long_rope_rotary_cos_sin is not None
@@ -1062,7 +1109,7 @@ def compute_cross_kv(encoder_output):
                 hidden_size_per_head=self.attention_head_size,
                 q_scaling=self.q_scaling,
                 rotary_embedding_dim=self.rotary_embedding_dim,
-                rotary_embedding_base=self.rotary_embedding_base,
+                rotary_embedding_base=self.rotary_embedding_base if not self.is_local else self.rotary_embedding_base_local,
                 rotary_embedding_scale_type=self.rotary_embedding_scale_type,
                 rotary_embedding_short_m_scale=attention_params.short_mscale,
                 rotary_embedding_long_m_scale=attention_params.long_mscale,
@@ -1071,8 +1118,8 @@ def compute_cross_kv(encoder_output):
                 rotary_embedding_original_max_positions=self.
                 original_max_position_embeddings,
                 position_embedding_type=self.position_embedding_type,
-                rotary_inv_freq=rotary_inv_freq,
-                rotary_cos_sin=rotary_cos_sin,
+                rotary_inv_freq=rotary_inv_freq if not self.is_local else rotary_inv_freq_local,
+                rotary_cos_sin=rotary_cos_sin if not self.is_local else rotary_cos_sin_local,
                 kv_orig_quant_scale=kv_orig_quant_scale,
                 kv_quant_orig_scale=kv_quant_orig_scale,
                 attention_output_orig_quant_scale=self.
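
How a model wires is_local is outside this diff; the sketch below is a hypothetical illustration (the 5-local-to-1-global interleave and the numeric bases are assumptions, loosely modeled on Gemma-3-style sliding-window layouts) of how a decoder might decide, per layer, whether its Attention instance should pick the local RoPE base and caches selected above.

GLOBAL_ROPE_BASE = 1_000_000.0   # assumed global rope_theta
LOCAL_ROPE_BASE = 10_000.0       # assumed config.rope_local_base_freq

def layer_is_local(layer_idx: int, sliding_window_pattern: int = 6) -> bool:
    """Every sliding_window_pattern-th layer attends globally; the rest are local."""
    return (layer_idx + 1) % sliding_window_pattern != 0

def rope_base_for_layer(layer_idx: int) -> float:
    # Mirrors the `... if not self.is_local else ..._local` selection in the diff.
    return LOCAL_ROPE_BASE if layer_is_local(layer_idx) else GLOBAL_ROPE_BASE

if __name__ == "__main__":
    print([rope_base_for_layer(i) for i in range(12)])
    # five local layers (10000.0), then one global layer (1000000.0), repeating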