
Commit 1d1fadd

feat: Add Gemma3 text-only model support
Signed-off-by: Balaram Buddharaju <[email protected]>
1 parent 943218b commit 1d1fadd

11 files changed (+219 −39 lines)


examples/gemma/README.md

+48
@@ -24,6 +24,8 @@
 - [Run inference under INT8 KV caches for keras checkpoint](#run-inference-under-int8-kv-caches-for-keras-checkpoint)
 - [Run Gemma 2](#run-gemma-2)
 - [Run inference under bfloat16 for torch checkpoint](#run-inference-under-bfloat16-for-torch-checkpoint-1)
+- [Run Gemma 3](#run-gemma-3)
+- [Run inference under bfloat16 for HF checkpoint](#run-inference-under-bfloat16-for-hf-checkpoint-1)
 - [Run Modelopt Quantization](#run-modelopt-quantization)
 - [Requirements](#requirements)
 - [Quantize Checkpoints](#quantize-checkpoints)
@@ -628,6 +630,52 @@ Average accuracy 0.697 - other (business, health, misc.)
 Average accuracy: 0.630
 ```

+### Run Gemma 3
+
+Gemma 3's text-generation model from Hugging Face is supported. The Gemma 3 1B model interleaves five local layers between each global layer: local layers use sliding-window attention with a short span of 512 tokens, while global layers attend to the full context. TRT-LLM supports layerwise sliding-window attention, and the sliding-window size for each layer can be passed in with the `--max_attention_window_size` parameter at runtime. If only a subpattern is provided, TRT-LLM extrapolates the complete pattern and prints the extrapolation logic to the terminal.
+
+#### Run inference under bfloat16 for HF checkpoint
+```bash
+git clone https://huggingface.co/google/gemma-3-1b-it
+
+CKPT_PATH=gemma-3-1b-it/
+UNIFIED_CKPT_PATH=/tmp/checkpoints/tmp_1b_it_tensorrt_llm/bf16/tp1/
+ENGINE_PATH=/tmp/gemma3/1b/bf16/1-gpu/
+VOCAB_FILE_PATH=gemma-3-1b-it/tokenizer.model
+
+python3 ./examples/gemma/convert_checkpoint.py \
+    --ckpt-type hf \
+    --model-dir ${CKPT_PATH} \
+    --dtype bfloat16 \
+    --world-size 1 \
+    --output-model-dir ${UNIFIED_CKPT_PATH}
+
+trtllm-build --checkpoint_dir ${UNIFIED_CKPT_PATH} \
+    --gemm_plugin auto \
+    --max_batch_size 8 \
+    --max_input_len 3000 \
+    --max_seq_len 3100 \
+    --output_dir ${ENGINE_PATH}
+
+python3 ./examples/summarize.py --test_trt_llm \
+    --vocab_file ${VOCAB_FILE_PATH} \
+    --engine_dir ${ENGINE_PATH} \
+    --batch_size 1 \
+    --max_ite 5 \
+    --max_attention_window_size 512 512 512 512 512 3100
+
+...
+[TensorRT-LLM][INFO] TRTGptModel mMaxAttentionWindowSize: (512, 512, 512, 512, 512, 3100) * 4 + (512, 512)
+...
+[04/09/2025-18:28:26] [TRT-LLM] [I] TensorRT-LLM (total latency: 1.6197962760925293 sec)
+[04/09/2025-18:28:26] [TRT-LLM] [I] TensorRT-LLM (total output tokens: 475)
+[04/09/2025-18:28:26] [TRT-LLM] [I] TensorRT-LLM (tokens per second: 293.2467539349165)
+[04/09/2025-18:28:26] [TRT-LLM] [I] TensorRT-LLM beam 0 result
+[04/09/2025-18:28:26] [TRT-LLM] [I] rouge1: 22.780314381954003
+[04/09/2025-18:28:26] [TRT-LLM] [I] rouge2: 4.331099231480823
+[04/09/2025-18:28:26] [TRT-LLM] [I] rougeL: 15.26751867562475
+[04/09/2025-18:28:26] [TRT-LLM] [I] rougeLsum: 20.14696930976001
+```

 ### Run Modelopt Quantization

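To make the window-size extrapolation above concrete, here is a minimal Python sketch (the `expand_window_sizes` helper and the 26-layer count are illustrative assumptions, not TRT-LLM's actual runtime code) of how the six-value subpattern tiles across the layers of Gemma 3 1B, matching the `mMaxAttentionWindowSize` log line in the example:

```python
# Hypothetical helper, not TRT-LLM code: tile a per-layer window subpattern
# cyclically until every decoder layer has a window size.
def expand_window_sizes(subpattern: list[int], num_layers: int) -> list[int]:
    return [subpattern[i % len(subpattern)] for i in range(num_layers)]

# Gemma 3 1B interleaves 5 local (512-token) layers with one global layer.
windows = expand_window_sizes([512, 512, 512, 512, 512, 3100], num_layers=26)
assert windows == [512, 512, 512, 512, 512, 3100] * 4 + [512, 512]
```
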
tensorrt_llm/layers/attention.py

+78 −26
@@ -175,6 +175,13 @@ def __init__(self,
         self.embed_positions = None
         self.rotary_inv_freq = None
         self.embed_positions_for_gpt_attention = None
+
+        # Auxiliary params to support models with non-homogeneous attention layers
+        # requiring a different set of RoPE params, e.g. Gemma3.
+        self.embed_positions_local = None
+        self.rotary_inv_freq_local = None
+        self.embed_positions_for_gpt_attention_local = None
+
         # long rope const parameters
         self.long_rope_embed_positions = None
         self.long_rope_rotary_inv_freq = None
@@ -186,10 +193,16 @@ def fill_attention_const_params_for_rope(
             self,
             embed_positions: Tensor = None,
             rotary_inv_freq: Tensor = None,
-            embed_positions_for_gpt_attention: Tensor = None):
+            embed_positions_for_gpt_attention: Tensor = None,
+            embed_positions_local: Tensor = None,
+            rotary_inv_freq_local: Tensor = None,
+            embed_positions_for_gpt_attention_local: Tensor = None):
         self.embed_positions = embed_positions
         self.rotary_inv_freq = rotary_inv_freq
         self.embed_positions_for_gpt_attention = embed_positions_for_gpt_attention
+        self.embed_positions_local = embed_positions_local
+        self.rotary_inv_freq_local = rotary_inv_freq_local
+        self.embed_positions_for_gpt_attention_local = embed_positions_for_gpt_attention_local
         return self

     def fill_attention_const_params_for_long_rope(
@@ -359,6 +372,7 @@ def __init__(self,
                  dtype=None,
                  position_embedding_type=PositionEmbeddingType.learned_absolute,
                  rotary_embedding_base=10000.0,
+                 rotary_embedding_base_local=1.0,
                  rotary_embedding_scaling=None,
                  rotary_embedding_percentage=1.0,
                  rope_scaling_short_factors=None,
@@ -388,7 +402,8 @@ def __init__(self,
                  cp_size=1,
                  cp_rank=0,
                  max_seqlen_for_logn_scaling=8192,
-                 use_logn_scaling=False):
+                 use_logn_scaling=False,
+                 is_local=False):
         super().__init__()

         self.local_layer_idx = local_layer_idx
@@ -417,6 +432,7 @@ def __init__(self,
         self.cp_group = cp_group
         self.cp_size = cp_size
         self.cp_rank = cp_rank
+        self.is_local = is_local

         self.num_layers = num_layers
         self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
@@ -437,6 +453,7 @@ def __init__(self,
         self.max_distance = max_distance
         self.num_buckets = num_buckets
         self.rotary_embedding_base = rotary_embedding_base
+        self.rotary_embedding_base_local = rotary_embedding_base_local
         self.rotary_embedding_scaling = rotary_embedding_scaling
         self.rotary_embedding_scale_type = RotaryScalingType.none
         self.rotary_embedding_scale = 1.0
@@ -656,26 +673,45 @@ def create_attention_const_params(model_cls, config):
             model_cls.short_mscale = short_mscale
             model_cls.long_mscale = long_mscale
         else:
-            # Rotary const weights.
-            embed_positions = RopeEmbeddingUtils.create_sinusoidal_positions(
-                max_position_embeddings,
-                rotary_embedding_dim,
-            )
-            rotary_inv_freq, embed_positions_for_gpt_attention = RopeEmbeddingUtils.create_sinusoidal_positions_for_attention_plugin(
-                max_position_embeddings, rotary_embedding_dim,
-                rotary_embedding_base, rotary_embedding_scale,
-                rotary_embedding_scale_type, rotary_embedding_scaling)
-            model_cls.register_parameter(
-                'embed_positions',
-                Parameter(embed_positions, dtype='float32', is_buffer=True))
-            model_cls.register_parameter(
-                'rotary_inv_freq',
-                Parameter(rotary_inv_freq, dtype='float32', is_buffer=True))
-            model_cls.register_parameter(
-                'embed_positions_for_gpt_attention',
-                Parameter(embed_positions_for_gpt_attention,
-                          dtype='float32',
-                          is_buffer=True))
+
+            def register_rope_params(rotary_base, names_to_register):
+                # Rotary const weights.
+                embed_positions = RopeEmbeddingUtils.create_sinusoidal_positions(
+                    max_position_embeddings,
+                    rotary_embedding_dim,
+                )
+                rotary_inv_freq, embed_positions_for_gpt_attention = RopeEmbeddingUtils.create_sinusoidal_positions_for_attention_plugin(
+                    max_position_embeddings, rotary_embedding_dim, rotary_base,
+                    rotary_embedding_scale, rotary_embedding_scale_type,
+                    rotary_embedding_scaling)
+                model_cls.register_parameter(
+                    names_to_register[0],
+                    Parameter(embed_positions, dtype='float32', is_buffer=True))
+                model_cls.register_parameter(
+                    names_to_register[1],
+                    Parameter(rotary_inv_freq, dtype='float32', is_buffer=True))
+                model_cls.register_parameter(
+                    names_to_register[2],
+                    Parameter(embed_positions_for_gpt_attention,
+                              dtype='float32',
+                              is_buffer=True))
+
+            register_rope_params(rotary_base=rotary_embedding_base,
+                                 names_to_register=[
+                                     'embed_positions', 'rotary_inv_freq',
+                                     'embed_positions_for_gpt_attention'
+                                 ])
+
+            # For models with non-homogeneous attention layers requiring a second set of RoPE params, e.g. Gemma3.
+            rotary_embedding_base_local = getattr(config,
+                                                  'rope_local_base_freq', None)
+            if rotary_embedding_base_local is not None:
+                register_rope_params(
+                    rotary_base=rotary_embedding_base_local,
+                    names_to_register=[
+                        'embed_positions_local', 'rotary_inv_freq_local',
+                        'embed_positions_for_gpt_attention_local'
+                    ])

     @staticmethod
     def fill_attention_params(model_cls, attention_params):
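For intuition about what the two `register_rope_params` calls above register, here is a rough sketch of the standard RoPE inverse-frequency formula evaluated for the two bases. The 1e6 / 1e4 base values and the 256 head size are assumptions about Gemma 3; this is not the `RopeEmbeddingUtils` implementation.

```python
import numpy as np

def rope_inv_freq(rotary_dim: int, base: float) -> np.ndarray:
    # Standard RoPE inverse frequencies: 1 / base^(2i/d) over even dimensions i.
    return 1.0 / base ** (np.arange(0, rotary_dim, 2, dtype=np.float32) / rotary_dim)

head_dim = 256  # assumed Gemma 3 head size
inv_freq_global = rope_inv_freq(head_dim, base=1_000_000.0)  # global layers (rope_theta)
inv_freq_local = rope_inv_freq(head_dim, base=10_000.0)      # local layers (rope_local_base_freq)
# The smaller local base yields larger inverse frequencies (shorter wavelengths),
# which suits the 512-token sliding-window layers.
```
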
@@ -695,7 +731,15 @@ def fill_attention_params(model_cls, attention_params):
            return attention_params.fill_attention_const_params_for_rope(
                model_cls.embed_positions.value,
                model_cls.rotary_inv_freq.value,
-               model_cls.embed_positions_for_gpt_attention.value)
+               model_cls.embed_positions_for_gpt_attention.value,
+               model_cls.embed_positions_local.value if hasattr(
+                   model_cls, "embed_positions_local") else None,
+               model_cls.rotary_inv_freq_local.value if hasattr(
+                   model_cls, "rotary_inv_freq_local") else None,
+               model_cls.embed_positions_for_gpt_attention_local.value
+               if hasattr(
+                   model_cls,
+                   "embed_positions_for_gpt_attention_local") else None)
        # Fill nothing.
        return attention_params

@@ -1020,6 +1064,11 @@ def compute_cross_kv(encoder_output):
        # Rotary cos/sin cache.
        rotary_cos_sin = getattr(attention_params,
                                 "embed_positions_for_gpt_attention", None)
+       rotary_inv_freq_local = getattr(attention_params,
+                                       "rotary_inv_freq_local", None)
+       rotary_cos_sin_local = getattr(
+           attention_params, "embed_positions_for_gpt_attention_local",
+           None)

        long_rope_rotary_inv_freq = getattr(attention_params,
                                            "long_rope_rotary_inv_freq",
@@ -1062,7 +1111,8 @@ def compute_cross_kv(encoder_output):
                hidden_size_per_head=self.attention_head_size,
                q_scaling=self.q_scaling,
                rotary_embedding_dim=self.rotary_embedding_dim,
-               rotary_embedding_base=self.rotary_embedding_base,
+               rotary_embedding_base=self.rotary_embedding_base
+               if not self.is_local else self.rotary_embedding_base_local,
                rotary_embedding_scale_type=self.rotary_embedding_scale_type,
                rotary_embedding_short_m_scale=attention_params.short_mscale,
                rotary_embedding_long_m_scale=attention_params.long_mscale,
@@ -1071,8 +1121,10 @@ def compute_cross_kv(encoder_output):
                rotary_embedding_original_max_positions=self.
                original_max_position_embeddings,
                position_embedding_type=self.position_embedding_type,
-               rotary_inv_freq=rotary_inv_freq,
-               rotary_cos_sin=rotary_cos_sin,
+               rotary_inv_freq=rotary_inv_freq
+               if not self.is_local else rotary_inv_freq_local,
+               rotary_cos_sin=rotary_cos_sin
+               if not self.is_local else rotary_cos_sin_local,
                kv_orig_quant_scale=kv_orig_quant_scale,
                kv_quant_orig_scale=kv_quant_orig_scale,
                attention_output_orig_quant_scale=self.
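How a layer ends up with `is_local=True` is decided in the Gemma model code rather than in this file. As an illustration only, the flag could be derived per layer from the sliding-window pattern; the modulo rule and `sliding_window_pattern=6` below are assumptions based on Gemma 3's published 5-local-to-1-global layout, not a quote of TRT-LLM's logic.

```python
# Hypothetical sketch: mark every sixth layer as global, the rest as local.
def is_local_layer(layer_idx: int, sliding_window_pattern: int = 6) -> bool:
    return (layer_idx + 1) % sliding_window_pattern != 0

# With 26 layers, indices 5, 11, 17 and 23 come out global. An Attention module
# built with is_local=True would then pick rotary_embedding_base_local,
# rotary_inv_freq_local and rotary_cos_sin_local in the plugin call shown above.
flags = [is_local_layer(i) for i in range(26)]
```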

tensorrt_llm/models/__init__.py

+3 −1
@@ -33,7 +33,8 @@
 from .enc_dec.model import DecoderModel, EncoderModel, WhisperEncoder
 from .falcon.config import FalconConfig
 from .falcon.model import FalconForCausalLM, FalconModel
-from .gemma.config import GEMMA2_ARCHITECTURE, GEMMA_ARCHITECTURE, GemmaConfig
+from .gemma.config import (GEMMA2_ARCHITECTURE, GEMMA3_ARCHITECTURE,
+                           GEMMA_ARCHITECTURE, GemmaConfig)
 from .gemma.model import GemmaForCausalLM
 from .gpt.config import GPTConfig
 from .gpt.model import GPTForCausalLM, GPTModel
@@ -183,6 +184,7 @@
     'SkyworkForCausalLM': LLaMAForCausalLM,
     GEMMA_ARCHITECTURE: GemmaForCausalLM,
     GEMMA2_ARCHITECTURE: GemmaForCausalLM,
+    GEMMA3_ARCHITECTURE: GemmaForCausalLM,
     'QWenLMHeadModel': QWenForCausalLM,
     'QWenForCausalLM': QWenForCausalLM,
     'Qwen2ForCausalLM': QWenForCausalLM,

tensorrt_llm/models/gemma/config.py

+30 −5
@@ -19,6 +19,7 @@
 from tensorrt_llm.mapping import Mapping
 from tensorrt_llm.models.convert_utils import infer_dtype
 from tensorrt_llm.models.modeling_utils import (Gemma2ConfigGroup,
+                                                Gemma3ConfigGroup,
                                                 PretrainedConfig, QuantConfig)

 if TYPE_CHECKING:
@@ -30,6 +31,7 @@

 GEMMA_ARCHITECTURE = "GemmaForCausalLM"
 GEMMA2_ARCHITECTURE = "Gemma2ForCausalLM"
+GEMMA3_ARCHITECTURE = "Gemma3ForCausalLM"


 class GemmaConfig(PretrainedConfig):
@@ -48,6 +50,9 @@ def __init__(
         final_logit_softcapping: Optional[float] = None,
         attn_logit_softcapping: Optional[float] = None,
         mapping: Optional[Union[Mapping, dict]] = None,
+        sliding_window_pattern: int = None,
+        rope_local_base_freq: int = None,
+        sliding_window: int = None,
         **kwargs,
     ):
         use_parallel_embedding = False
@@ -79,23 +84,29 @@ def __init__(
         self.mlp_bias = mlp_bias

         self.inter_layernorms = False
-        if self.is_gemma_2:
+        if self.is_gemma_2 or self.is_gemma_3:
             self.inter_layernorms = True
-            assert query_pre_attn_scalar is not None, "Gemma2 models must configure `query_pre_attn_scalar`"
+            assert query_pre_attn_scalar is not None, "Gemma2 / Gemma3 models must configure `query_pre_attn_scalar`"
             self.query_pre_attn_scalar = query_pre_attn_scalar
             self.final_logit_softcapping = final_logit_softcapping
-            self.attn_logit_softcapping = attn_logit_softcapping
+            if self.is_gemma_2:
+                self.attn_logit_softcapping = attn_logit_softcapping
+            if self.is_gemma_3:
+                self.sliding_window_pattern = sliding_window_pattern
+                self.rope_local_base_freq = rope_local_base_freq
+                self.sliding_window = sliding_window

     GEMMA_ADDED_FIELDS = {
         "rotary_base", "rotary_scaling", "attn_bias", "mlp_bias",
         "inter_layernorms"
     }
     GEMMA2_ADDED_FIELDS = Gemma2ConfigGroup.keys()
+    GEMMA3_ADDED_FIELDS = Gemma3ConfigGroup.keys()
     VERBATIM = {
         "num_hidden_layers", "num_attention_heads", "hidden_size",
         "intermediate_size", "vocab_size", "max_position_embeddings",
         "hidden_act", "use_parallel_embedding"
-    } | GEMMA2_ADDED_FIELDS
+    } | GEMMA2_ADDED_FIELDS | GEMMA3_ADDED_FIELDS

     @property
     def is_gemma_2(self) -> bool:
@@ -106,6 +117,15 @@ def gemma2_config(self):
             return self.get_config_group(Gemma2ConfigGroup)
         return None

+    @property
+    def is_gemma_3(self) -> bool:
+        return self.architecture == GEMMA3_ARCHITECTURE
+
+    def gemma3_config(self):
+        if self.is_gemma_3:
+            return self.get_config_group(Gemma3ConfigGroup)
+        return None
+
     def to_dict(self):
         """Serialize the fields added in GemmaConfig"""

@@ -118,7 +138,11 @@ def to_dict(self):
             **({
                 f: getattr(self, f)
                 for f in self.GEMMA2_ADDED_FIELDS
-            } if self.is_gemma_2 else {})
+            } if self.is_gemma_2 else {}),
+            **({
+                f: getattr(self, f)
+                for f in self.GEMMA3_ADDED_FIELDS
+            } if self.is_gemma_3 else {})
         }

     @classmethod
@@ -148,6 +172,7 @@ def from_hugging_face(
             norm_epsilon=hf_config.rms_norm_eps,
             num_key_value_heads=getattr(hf_config, "num_key_value_heads",
                                         hf_config.num_attention_heads),
+            rotary_base=getattr(hf_config, "rope_theta", 10000.0),
             rotary_scaling=getattr(hf_config, "rotary_scaling", None),
             quantization=quant_config,
             mapping=mapping,
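The new fields are read from the Hugging Face config by `from_hugging_face`. As a reference point, a Gemma 3 1B style `config.json` exposes roughly the following keys; the concrete values below are assumptions for illustration, so check the checkpoint's actual config.

```python
# Illustrative slice of a Gemma 3 1B Hugging Face config; values are assumed.
hf_config_like = {
    "architectures": ["Gemma3ForCausalLM"],   # maps to GEMMA3_ARCHITECTURE
    "num_hidden_layers": 26,
    "sliding_window": 512,                    # span of the local layers
    "sliding_window_pattern": 6,              # 5 local layers per global layer
    "rope_theta": 1_000_000.0,                # global-layer RoPE base -> rotary_base
    "rope_local_base_freq": 10_000.0,         # local-layer RoPE base
    "query_pre_attn_scalar": 256,             # required by the assert above
}
```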

tensorrt_llm/models/gemma/convert.py

+6
@@ -317,6 +317,10 @@ def rename_to_trt_llm(self, name: str) -> Optional[str]:
              None),  # merged with above
             (r"model.layers.(\d+).self_attn.o_proj.weight",
              r"layers.\1.attention.dense.weight"),
+            (r"model.layers.(\d+).self_attn.q_norm.weight",
+             r"layers.\1.attention.q_layernorm.weight"),
+            (r"model.layers.(\d+).self_attn.k_norm.weight",
+             r"layers.\1.attention.k_layernorm.weight"),
             (r"model.layers.(\d+).mlp.gate_proj.weight",
              r"layers.\1.mlp.fc.weight"),
             (r"model.layers.(\d+).mlp.up_proj.weight",
@@ -795,6 +799,8 @@ def load_gemma_weights(
                 "pre_feedforward_layernorm",
                 "post_feedforward_layernorm",
                 "model.norm.weight",
+                "q_norm.weight",
+                "k_norm.weight",
             )):
                 param = param + 1.0  # upcasted to float32 in case of bfloat16
                 add_trt_llm_weight(weights, trt_llm_name, param,
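The two added rename rules map the HF `q_norm`/`k_norm` weights onto TRT-LLM's `q_layernorm`/`k_layernorm` parameters. Here is a minimal sketch of how such (pattern, replacement) pairs are typically applied; the `apply_rename` helper below is illustrative, not the actual `rename_to_trt_llm` method.

```python
import re

# Illustrative subset of the (pattern, replacement) rules added above.
RENAME_RULES = [
    (r"model.layers.(\d+).self_attn.q_norm.weight",
     r"layers.\1.attention.q_layernorm.weight"),
    (r"model.layers.(\d+).self_attn.k_norm.weight",
     r"layers.\1.attention.k_layernorm.weight"),
]

def apply_rename(name: str) -> str:
    for pattern, repl in RENAME_RULES:
        if re.fullmatch(pattern, name):
            return re.sub(pattern, repl, name)
    return name  # unchanged if no rule matches

assert apply_rename("model.layers.3.self_attn.q_norm.weight") == \
    "layers.3.attention.q_layernorm.weight"
```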
