Commit 1e31450

feat: Add Gemma3 text-only model support
1 parent 3a8443f commit 1e31450

8 files changed: +132 −20 lines

examples/utils.py

+2 −1

@@ -337,7 +337,8 @@ def add_common_args(parser):
     parser.add_argument(
         '--max_attention_window_size',
         type=int,
-        default=None,
+        default=[512, 512, 512, 512, 512, 2048, 512, 512, 512, 512, 512, 2048, 512, 512, 512, 512, 512, 2048, 512, 512, 512, 512, 512, 2048, 512, 512],
+        # default=None,
         nargs="+",
         help=
         'The attention window size that controls the sliding window attention / cyclic kv cache behavior'
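The hard-coded default above encodes Gemma3's interleaved attention schedule: one entry per decoder layer, a short window for the sliding (local) layers and a larger window on every sixth layer, which attends globally. A minimal sketch (not part of the commit) of how such a list can be generated, reusing the same `(layer_idx + 1) % sliding_window_pattern` test that model.py introduces below; the layer count and window sizes are simply read off the default list and are model-specific:

    # Illustrative only: rebuild the per-layer window list from the local/global pattern.
    def per_layer_windows(num_layers=26, pattern=6, local_window=512, global_window=2048):
        return [
            local_window if (layer_idx + 1) % pattern else global_window
            for layer_idx in range(num_layers)
        ]

    windows = per_layer_windows()
    assert windows[:6] == [512, 512, 512, 512, 512, 2048]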

tensorrt_llm/layers/attention.py

+53 −6

@@ -175,6 +175,11 @@ def __init__(self,
         self.embed_positions = None
         self.rotary_inv_freq = None
         self.embed_positions_for_gpt_attention = None
+
+        self.embed_positions_local = None
+        self.rotary_inv_freq_local = None
+        self.embed_positions_for_gpt_attention_local = None
+
         # long rope const parameters
         self.long_rope_embed_positions = None
         self.long_rope_rotary_inv_freq = None
@@ -186,10 +191,16 @@ def fill_attention_const_params_for_rope(
             self,
             embed_positions: Tensor = None,
             rotary_inv_freq: Tensor = None,
-            embed_positions_for_gpt_attention: Tensor = None):
+            embed_positions_for_gpt_attention: Tensor = None,
+            embed_positions_local: Tensor = None,
+            rotary_inv_freq_local: Tensor = None,
+            embed_positions_for_gpt_attention_local: Tensor = None):
         self.embed_positions = embed_positions
         self.rotary_inv_freq = rotary_inv_freq
         self.embed_positions_for_gpt_attention = embed_positions_for_gpt_attention
+        self.embed_positions_local = embed_positions_local
+        self.rotary_inv_freq_local = rotary_inv_freq_local
+        self.embed_positions_for_gpt_attention_local = embed_positions_for_gpt_attention_local
         return self

     def fill_attention_const_params_for_long_rope(
@@ -359,6 +370,7 @@ def __init__(self,
                  dtype=None,
                  position_embedding_type=PositionEmbeddingType.learned_absolute,
                  rotary_embedding_base=10000.0,
+                 rotary_embedding_base_local=1.0,
                  rotary_embedding_scaling=None,
                  rotary_embedding_percentage=1.0,
                  rope_scaling_short_factors=None,
@@ -388,7 +400,8 @@ def __init__(self,
                  cp_size=1,
                  cp_rank=0,
                  max_seqlen_for_logn_scaling=8192,
-                 use_logn_scaling=False):
+                 use_logn_scaling=False,
+                 is_local=False):
         super().__init__()

         self.local_layer_idx = local_layer_idx
@@ -417,6 +430,7 @@ def __init__(self,
         self.cp_group = cp_group
         self.cp_size = cp_size
         self.cp_rank = cp_rank
+        self.is_local = is_local

         self.num_layers = num_layers
         self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
@@ -437,6 +451,7 @@ def __init__(self,
         self.max_distance = max_distance
         self.num_buckets = num_buckets
         self.rotary_embedding_base = rotary_embedding_base
+        self.rotary_embedding_base_local = rotary_embedding_base_local
         self.rotary_embedding_scaling = rotary_embedding_scaling
         self.rotary_embedding_scale_type = RotaryScalingType.none
         self.rotary_embedding_scale = 1.0
@@ -677,6 +692,29 @@ def create_attention_const_params(model_cls, config):
                           dtype='float32',
                           is_buffer=True))

+        rotary_embedding_base_local = getattr(config, 'rope_local_base_freq', None)
+        if rotary_embedding_base_local is not None:
+            embed_positions_local = RopeEmbeddingUtils.create_sinusoidal_positions(
+                max_position_embeddings,
+                rotary_embedding_dim,
+            )
+            rotary_inv_freq_local, embed_positions_for_gpt_attention_local = RopeEmbeddingUtils.create_sinusoidal_positions_for_attention_plugin(
+                max_position_embeddings, rotary_embedding_dim,
+                rotary_embedding_base_local, rotary_embedding_scale,
+                rotary_embedding_scale_type, rotary_embedding_scaling)
+            model_cls.register_parameter(
+                'embed_positions_local',
+                Parameter(embed_positions_local, dtype='float32', is_buffer=True))
+            model_cls.register_parameter(
+                'rotary_inv_freq_local',
+                Parameter(rotary_inv_freq_local, dtype='float32', is_buffer=True))
+            model_cls.register_parameter(
+                'embed_positions_for_gpt_attention_local',
+                Parameter(embed_positions_for_gpt_attention_local,
+                          dtype='float32',
+                          is_buffer=True))
+
+
     @staticmethod
     def fill_attention_params(model_cls, attention_params):
         if model_cls.position_embedding_type.is_rope():
@@ -695,7 +733,10 @@ def fill_attention_params(model_cls, attention_params):
             return attention_params.fill_attention_const_params_for_rope(
                 model_cls.embed_positions.value,
                 model_cls.rotary_inv_freq.value,
-                model_cls.embed_positions_for_gpt_attention.value)
+                model_cls.embed_positions_for_gpt_attention.value,
+                model_cls.embed_positions_local.value,
+                model_cls.rotary_inv_freq_local.value,
+                model_cls.embed_positions_for_gpt_attention_local.value)
         # Fill nothing.
         return attention_params

@@ -1020,6 +1061,9 @@ def compute_cross_kv(encoder_output):
             # Rotary cos/sin cache.
             rotary_cos_sin = getattr(attention_params,
                                      "embed_positions_for_gpt_attention", None)
+            rotary_inv_freq_local = getattr(attention_params, "rotary_inv_freq_local", None)
+            rotary_cos_sin_local = getattr(attention_params,
+                                           "embed_positions_for_gpt_attention_local", None)

             long_rope_rotary_inv_freq = getattr(attention_params,
                                                 "long_rope_rotary_inv_freq",
@@ -1037,6 +1081,9 @@ def compute_cross_kv(encoder_output):
             assert (rotary_inv_freq is not None) and (
                 rotary_cos_sin is not None
             ), "rotary_inv_freq and embed_positions_for_gpt_attention must be provided."
+            assert (rotary_inv_freq_local is not None) and (
+                rotary_cos_sin_local is not None
+            ), "rotary_inv_freq_local and embed_positions_for_gpt_attention_local must be provided."
             if self.position_embedding_type == PositionEmbeddingType.long_rope:
                 assert long_rope_rotary_inv_freq is not None
                 assert long_rope_rotary_cos_sin is not None
@@ -1062,7 +1109,7 @@ def compute_cross_kv(encoder_output):
                 hidden_size_per_head=self.attention_head_size,
                 q_scaling=self.q_scaling,
                 rotary_embedding_dim=self.rotary_embedding_dim,
-                rotary_embedding_base=self.rotary_embedding_base,
+                rotary_embedding_base=self.rotary_embedding_base if not self.is_local else self.rotary_embedding_base_local,
                 rotary_embedding_scale_type=self.rotary_embedding_scale_type,
                 rotary_embedding_short_m_scale=attention_params.short_mscale,
                 rotary_embedding_long_m_scale=attention_params.long_mscale,
@@ -1071,8 +1118,8 @@ def compute_cross_kv(encoder_output):
                 rotary_embedding_original_max_positions=self.
                 original_max_position_embeddings,
                 position_embedding_type=self.position_embedding_type,
-                rotary_inv_freq=rotary_inv_freq,
-                rotary_cos_sin=rotary_cos_sin,
+                rotary_inv_freq=rotary_inv_freq if not self.is_local else rotary_inv_freq_local,
+                rotary_cos_sin=rotary_cos_sin if not self.is_local else rotary_cos_sin_local,
                 kv_orig_quant_scale=kv_orig_quant_scale,
                 kv_quant_orig_scale=kv_quant_orig_scale,
                 attention_output_orig_quant_scale=self.
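The new `*_local` tables registered above hold the same kind of sinusoidal RoPE caches as the existing ones, just computed with a different base frequency, so sliding-window (local) layers rotate positions on a shorter wavelength than global layers. A rough sketch of the quantity involved, assuming the standard RoPE inverse-frequency formula; the two base values are placeholders, not taken from this commit:

    import numpy as np

    # inv_freq[i] = base ** (-2i / rotary_dim); only the base differs between the
    # global table and the local table created in create_attention_const_params.
    def rope_inv_freq(base: float, rotary_dim: int = 128) -> np.ndarray:
        return base ** (-np.arange(0, rotary_dim, 2, dtype=np.float64) / rotary_dim)

    inv_freq_global = rope_inv_freq(base=1_000_000.0)  # e.g. config.rotary_base
    inv_freq_local = rope_inv_freq(base=10_000.0)      # e.g. config.rope_local_base_freq

At attention time the plugin receives either the global or the local pair, selected per layer through the new `is_local` flag (`rotary_inv_freq if not self.is_local else rotary_inv_freq_local`).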

tensorrt_llm/models/__init__.py

+2 −1

@@ -33,7 +33,7 @@
 from .enc_dec.model import DecoderModel, EncoderModel, WhisperEncoder
 from .falcon.config import FalconConfig
 from .falcon.model import FalconForCausalLM, FalconModel
-from .gemma.config import GEMMA2_ARCHITECTURE, GEMMA_ARCHITECTURE, GemmaConfig
+from .gemma.config import GEMMA3_ARCHITECTURE, GEMMA2_ARCHITECTURE, GEMMA_ARCHITECTURE, GemmaConfig
 from .gemma.model import GemmaForCausalLM
 from .gpt.config import GPTConfig
 from .gpt.model import GPTForCausalLM, GPTModel
@@ -183,6 +183,7 @@
     'SkyworkForCausalLM': LLaMAForCausalLM,
     GEMMA_ARCHITECTURE: GemmaForCausalLM,
     GEMMA2_ARCHITECTURE: GemmaForCausalLM,
+    GEMMA3_ARCHITECTURE: GemmaForCausalLM,
     'QWenLMHeadModel': QWenForCausalLM,
     'QWenForCausalLM': QWenForCausalLM,
     'Qwen2ForCausalLM': QWenForCausalLM,
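With the registry entry above, a checkpoint whose config reports the Gemma3 architecture resolves to the existing GemmaForCausalLM implementation. A sketch of that lookup, assuming the dict being edited here is this module's MODEL_MAP registry (the dict's name is not visible in this excerpt):

    # Assumption: MODEL_MAP is the architecture-to-class dict shown in the hunk above.
    from tensorrt_llm.models import GEMMA3_ARCHITECTURE, GemmaForCausalLM, MODEL_MAP

    assert MODEL_MAP[GEMMA3_ARCHITECTURE] is GemmaForCausalLM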

tensorrt_llm/models/gemma/config.py

+31 −4

@@ -12,13 +12,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import TYPE_CHECKING, Optional, Union
+from typing import TYPE_CHECKING, Optional, Union, List

 from tensorrt_llm.functional import PositionEmbeddingType
 from tensorrt_llm.logger import logger
 from tensorrt_llm.mapping import Mapping
 from tensorrt_llm.models.convert_utils import infer_dtype
-from tensorrt_llm.models.modeling_utils import (Gemma2ConfigGroup,
+from tensorrt_llm.models.modeling_utils import (Gemma2ConfigGroup, Gemma3ConfigGroup,
                                                 PretrainedConfig, QuantConfig)

 if TYPE_CHECKING:
@@ -30,6 +30,7 @@

 GEMMA_ARCHITECTURE = "GemmaForCausalLM"
 GEMMA2_ARCHITECTURE = "Gemma2ForCausalLM"
+GEMMA3_ARCHITECTURE = "Gemma3ForCausalLM"


 class GemmaConfig(PretrainedConfig):
@@ -48,6 +49,9 @@ def __init__(
         final_logit_softcapping: Optional[float] = None,
         attn_logit_softcapping: Optional[float] = None,
         mapping: Optional[Union[Mapping, dict]] = None,
+        sliding_window_pattern: int = None,
+        rope_local_base_freq: int = None,
+        sliding_window: int = None,
         **kwargs,
     ):
         use_parallel_embedding = False
@@ -85,17 +89,26 @@ def __init__(
             self.query_pre_attn_scalar = query_pre_attn_scalar
             self.final_logit_softcapping = final_logit_softcapping
             self.attn_logit_softcapping = attn_logit_softcapping
+        elif self.is_gemma_3:
+            self.inter_layernorms = True
+            assert query_pre_attn_scalar is not None, "Gemma3 models must configure `query_pre_attn_scalar`"
+            self.query_pre_attn_scalar = query_pre_attn_scalar
+            self.final_logit_softcapping = final_logit_softcapping
+            self.sliding_window_pattern = sliding_window_pattern
+            self.rope_local_base_freq = rope_local_base_freq
+            self.sliding_window = sliding_window

     GEMMA_ADDED_FIELDS = {
         "rotary_base", "rotary_scaling", "attn_bias", "mlp_bias",
         "inter_layernorms"
     }
     GEMMA2_ADDED_FIELDS = Gemma2ConfigGroup.keys()
+    GEMMA3_ADDED_FIELDS = Gemma3ConfigGroup.keys()
     VERBATIM = {
         "num_hidden_layers", "num_attention_heads", "hidden_size",
         "intermediate_size", "vocab_size", "max_position_embeddings",
         "hidden_act", "use_parallel_embedding"
-    } | GEMMA2_ADDED_FIELDS
+    } | GEMMA2_ADDED_FIELDS | GEMMA3_ADDED_FIELDS

     @property
     def is_gemma_2(self) -> bool:
@@ -106,6 +119,15 @@ def gemma2_config(self):
             return self.get_config_group(Gemma2ConfigGroup)
         return None

+    @property
+    def is_gemma_3(self) -> bool:
+        return self.architecture == GEMMA3_ARCHITECTURE
+
+    def gemma3_config(self):
+        if self.is_gemma_3:
+            return self.get_config_group(Gemma3ConfigGroup)
+        return None
+
     def to_dict(self):
         """Serialize the fields added in GemmaConfig"""

@@ -118,7 +140,11 @@ def to_dict(self):
             **({
                 f: getattr(self, f)
                 for f in self.GEMMA2_ADDED_FIELDS
-            } if self.is_gemma_2 else {})
+            } if self.is_gemma_2 else {}),
+            **({
+                f: getattr(self, f)
+                for f in self.GEMMA3_ADDED_FIELDS
+            } if self.is_gemma_3 else {})
         }

     @classmethod
@@ -148,6 +174,7 @@ def from_hugging_face(
             norm_epsilon=hf_config.rms_norm_eps,
             num_key_value_heads=getattr(hf_config, "num_key_value_heads",
                                         hf_config.num_attention_heads),
+            rotary_base=getattr(hf_config, "rope_theta", 10000.0),
             rotary_scaling=getattr(hf_config, "rotary_scaling", None),
             quantization=quant_config,
             mapping=mapping,
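For orientation, the Gemma3-related fields that GemmaConfig now understands, written as a Hugging Face-style config excerpt. Only the key names come from this diff; the values are illustrative placeholders, and only rope_theta is shown being forwarded (as rotary_base) in the from_hugging_face hunk above:

    hf_config_excerpt = {
        "architectures": ["Gemma3ForCausalLM"],  # selects is_gemma_3
        "query_pre_attn_scalar": 256,            # required; asserted in __init__
        "rope_theta": 1_000_000.0,               # -> rotary_base (global RoPE base)
        "rope_local_base_freq": 10_000.0,        # RoPE base for sliding layers
        "sliding_window": 512,                   # local attention window
        "sliding_window_pattern": 6,             # every 6th layer attends globally
    }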

tensorrt_llm/models/gemma/convert.py

+6 −0

@@ -317,6 +317,10 @@ def rename_to_trt_llm(self, name: str) -> Optional[str]:
              None),  # merged with above
             (r"model.layers.(\d+).self_attn.o_proj.weight",
              r"layers.\1.attention.dense.weight"),
+            (r"model.layers.(\d+).self_attn.q_norm.weight",
+             r"layers.\1.attention.q_layernorm.weight"),
+            (r"model.layers.(\d+).self_attn.k_norm.weight",
+             r"layers.\1.attention.k_layernorm.weight"),
             (r"model.layers.(\d+).mlp.gate_proj.weight",
              r"layers.\1.mlp.fc.weight"),
             (r"model.layers.(\d+).mlp.up_proj.weight",
@@ -795,6 +799,8 @@ def load_gemma_weights(
                 "pre_feedforward_layernorm",
                 "post_feedforward_layernorm",
                 "model.norm.weight",
+                "q_norm.weight",
+                "k_norm.weight",
         )):
             param = param + 1.0  # upcasted to float32 in case of bfloat16
             add_trt_llm_weight(weights, trt_llm_name, param,
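A quick, self-contained check (not part of the commit) that the two new rename rules map the Hugging Face q_norm/k_norm weights onto the q_layernorm/k_layernorm parameters created by qk_layernorm=True in model.py:

    import re

    rules = [
        (r"model.layers.(\d+).self_attn.q_norm.weight",
         r"layers.\1.attention.q_layernorm.weight"),
        (r"model.layers.(\d+).self_attn.k_norm.weight",
         r"layers.\1.attention.k_layernorm.weight"),
    ]

    name = "model.layers.7.self_attn.q_norm.weight"
    for pattern, repl in rules:
        if re.fullmatch(pattern, name):
            print(re.sub(pattern, repl, name))  # layers.7.attention.q_layernorm.weight

The second hunk adds the same norms to the list of names that receive the + 1.0 offset, consistent with how the other Gemma norm weights are already loaded.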

tensorrt_llm/models/gemma/model.py

+19 −4

@@ -24,7 +24,7 @@
 from ..._common import default_net
 from ..._utils import pad_vocab_size
 from ...functional import (AllReduceFusionOp, AllReduceParams, Tensor, cast,
-                           recv, send)
+                           recv, send, LayerNormType)
 from ...layers import (Attention, AttentionMaskType, AttentionParams,
                        ColumnLinear, Embedding, GatedMLP, KeyValueCacheParams,
                        LoraParams, PositionEmbeddingType, RmsNorm)
@@ -56,32 +56,48 @@ def __init__(self, config: GemmaConfig, layer_idx: int):

         q_scaling = 1.0
         max_attn_value = 0.0
+        qk_layernorm = False
+        is_sliding = False
+        rotary_base = config.rotary_base
+        rotary_base_local = None

         gemma2_config = config.gemma2_config()
+        gemma3_config = config.gemma3_config()
         if gemma2_config:
             q_scaling = math.sqrt(
                 gemma2_config.query_pre_attn_scalar) / math.sqrt(
                     config.head_size)
             max_attn_value = config.attn_logit_softcapping or 0.0
+        elif gemma3_config:
+            qk_layernorm = True
+            q_scaling = math.sqrt(
+                gemma3_config.query_pre_attn_scalar) / math.sqrt(
+                    config.head_size)
+            is_sliding = bool((layer_idx + 1) % gemma3_config.sliding_window_pattern)
+            rotary_base_local = config.rope_local_base_freq

         self.attention = Attention(
             local_layer_idx=self.local_layer_idx,
             hidden_size=config.hidden_size,
             num_attention_heads=config.num_attention_heads,
             num_kv_heads=config.num_key_value_heads,
             attention_head_size=config.head_size,
+            qk_layernorm=qk_layernorm,
+            layernorm_type=LayerNormType.RmsNorm,
             max_position_embeddings=config.max_position_embeddings,
             dtype=config.dtype,
             attention_mask_type=AttentionMaskType.causal,
             bias=config.attn_bias,
             position_embedding_type=PositionEmbeddingType.rope_gpt_neox,
-            rotary_embedding_base=config.rotary_base,
+            rotary_embedding_base=rotary_base,
+            rotary_embedding_base_local=rotary_base_local,
             rotary_embedding_scaling=config.rotary_scaling,
             tp_group=config.mapping.tp_group,
             tp_size=config.mapping.tp_size,
             quant_mode=config.quant_mode,
             q_scaling=q_scaling,
             max_attn_value=max_attn_value,
+            is_local=is_sliding,
         )

         mlp_hidden_size = config.hidden_size * 4 if config.intermediate_size is None else config.intermediate_size
@@ -223,8 +239,7 @@ def forward(self,

         if self.mapping.is_first_pp_rank():
             hidden_states = self.vocab_embedding(input_ids, *ptuning_args)
-            hidden_states = cast(hidden_states * math.sqrt(self.hidden_size),
-                                 hidden_states.dtype)
+            hidden_states = cast(hidden_states * math.sqrt(self.hidden_size), hidden_states.dtype)
         else:
             hidden_states = recv(hidden_states, self.mapping.prev_pp_rank())
         hidden_states = self.layers.forward(
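The two Gemma3-specific quantities the decoder layer derives above, worked through with illustrative placeholder values (the real ones come from the checkpoint's config, not from this commit):

    import math

    query_pre_attn_scalar = 256   # placeholder, model-specific
    head_size = 256               # placeholder, model-specific
    sliding_window_pattern = 6    # placeholder, model-specific

    # Same expression as the elif gemma3_config branch; with these placeholders it is 1.0.
    q_scaling = math.sqrt(query_pre_attn_scalar) / math.sqrt(head_size)

    # Same expression as is_sliding: every sliding_window_pattern-th layer is global.
    for layer_idx in range(12):
        kind = "local" if (layer_idx + 1) % sliding_window_pattern else "global"
        print(layer_idx, kind)    # layers 5 and 11 print "global"

Sliding layers pass is_local=True into Attention and therefore pick up rotary_embedding_base_local, while global layers keep the plain rotary_embedding_base.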
