
feat: Support gemma-3-1b-it #3247


Merged 1 commit into NVIDIA:main on Apr 10, 2025

Conversation

@brb-nv (Collaborator) commented Apr 2, 2025

This MR adds model support for gemma-3-1b-it.

Details:

  • This model has attention layers of two types: sliding-window attention with a 10k RoPE base, and regular (global) attention with a 1M RoPE base.
  • Most models in TRT-LLM use attention layers of a single type, so there is only one set of RoPE params under AttentionParams passed to the model's forward pass.
  • In this case, however, we need two sets of RoPE params, one for each layer type mentioned above.
  • This MR adds one more set of RoPE fields under AttentionParams, with a _local suffix, to handle the additional layer type.
  • The alternative is to duplicate a fair amount of functionality from modeling_utils.py (DecoderLayerList and DecoderModelForCausalLM) and carefully orchestrate the forward pass by passing a different set of AttentionParams to each layer type. That felt like significant code duplication and maintenance overhead.
  • The sliding-window requirement is communicated to the model at runtime via max_attention_window_size with run.py (or whichever API is being used), e.g.:
    [512, 512, 512, 512, 512, 2048, 512, 512, 512, 512, 512, 2048, 512, 512, 512, 512, 512, 2048, 512, 512, 512, 512, 512, 2048, 512, 512]
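The window list above can be reproduced programmatically. This is a hypothetical sketch, not code from this PR: it assumes gemma-3-1b-it has 26 decoder layers where every sixth layer uses global attention (window 2048 in the example above) and the rest use a 512-token sliding window.

```python
# Hypothetical sketch: build the per-layer max_attention_window_size list
# shown above, assuming 26 layers with every sixth layer (indices 5, 11,
# 17, 23) using global attention and the rest using a sliding window.
NUM_LAYERS = 26
SLIDING_WINDOW = 512   # sliding-window (local) attention layers
GLOBAL_WINDOW = 2048   # global attention layers, per the example above

max_attention_window_size = [
    GLOBAL_WINDOW if (i + 1) % 6 == 0 else SLIDING_WINDOW
    for i in range(NUM_LAYERS)
]
```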
$ rm -rf gemma3_1b_ckpt/ && python3 ./examples/gemma/convert_checkpoint.py --ckpt-type hf --model-dir ../random/hf_models/gemma-3-1b-it/ --dtype float16 --output-model-dir gemma3_1b_ckpt/
$ rm -rf gemma3_1b_eng/ && trtllm-build --checkpoint_dir gemma3_1b_ckpt/ --output_dir gemma3_1b_eng/trtllm_engine --max_batch_size 256 --max_seq_len 32768 --max_num_tokens 32768 --workers 1 --use_paged_context_fmha disable --gpus_per_node 1 --nccl_plugin auto
$ rm -rf tllm_debug/ && python3 examples/run.py --max_output_len 512 --max_input_length 2048 --input_text 'The main cities in Italy are (Write a blog post)' --engine_dir gemma3_1b_eng/trtllm_engine --tokenizer_dir ../random/hf_models/gemma-3-1b-it/ --end_id 106
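As a rough illustration of the two-RoPE-sets idea described above, each layer can select the set of RoPE fields matching its attention type from a single shared params object. All names here (AttentionParams fields, rope_base_for_layer) are illustrative, not the actual TRT-LLM API.

```python
from dataclasses import dataclass

# Hypothetical sketch of the "_local suffix" approach described above.
# Field and function names are illustrative, not the real TRT-LLM code.
@dataclass
class AttentionParams:
    rope_theta: float = 1_000_000.0     # regular (global) attention layers
    rope_theta_local: float = 10_000.0  # sliding-window attention layers

def rope_base_for_layer(params: AttentionParams, is_sliding_window: bool) -> float:
    # A single params object serves both layer types; each layer simply
    # picks the matching set of RoPE fields in its forward pass, avoiding
    # a duplicated DecoderLayerList/forward-pass orchestration.
    return params.rope_theta_local if is_sliding_window else params.rope_theta

params = AttentionParams()
global_base = rope_base_for_layer(params, is_sliding_window=False)
local_base = rope_base_for_layer(params, is_sliding_window=True)
```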

@brb-nv brb-nv marked this pull request as draft April 2, 2025 18:30
@brb-nv brb-nv changed the title draft: Gemma feat: Gemma Apr 4, 2025
@brb-nv (Collaborator, Author) commented Apr 4, 2025

/bot run

@tensorrt-cicd (Collaborator)

PR_Github #1179 [ run ] triggered by Bot

@tensorrt-cicd (Collaborator)

PR_Github #1179 [ run ] completed with state FAILURE
/LLM/main/L0_MergeRequest_PR pipeline #885 completed with status: 'FAILURE'

@brb-nv brb-nv force-pushed the user/brb/gemma3-on-main branch 4 times, most recently from 1e31450 to 6c4f0df on April 9, 2025 04:01
@brb-nv brb-nv changed the title feat: Gemma feat: Support gemma-3-1b-it Apr 9, 2025
@brb-nv brb-nv marked this pull request as ready for review April 9, 2025 04:27
@brb-nv (Collaborator, Author) commented Apr 9, 2025

/bot run

@brb-nv brb-nv requested a review from amukkara April 9, 2025 04:30
@tensorrt-cicd (Collaborator)

PR_Github #1546 [ run ] triggered by Bot

@brb-nv brb-nv requested a review from schetlur-nv April 9, 2025 04:33
@tensorrt-cicd (Collaborator)

PR_Github #1546 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #1155 completed with status: 'FAILURE'

@bebilli commented Apr 9, 2025

Does it support Gemma-3-27B?

@brb-nv brb-nv self-assigned this Apr 9, 2025
@amukkara (Collaborator) commented Apr 9, 2025

@brb-nv, can you update examples/gemma/README.md with Gemma 3 instructions? Specifically, clearly show how to set the correct attention window values.

Ideally, the attention window values should be part of the model definition so that users don't have to provide them. We can add that in a follow-up PR.

@brb-nv (Collaborator, Author) commented Apr 9, 2025

> Does it support Gemma-3-27B?

Hi, the goal is to add support for the text-generation model first. We'll get to the multimodal models in a follow-up MR.

@brb-nv brb-nv force-pushed the user/brb/gemma3-on-main branch from ad0e30e to 25e1494 on April 9, 2025 17:57
@brb-nv brb-nv force-pushed the user/brb/gemma3-on-main branch 3 times, most recently from 2e9df5a to 1ad80ea on April 9, 2025 19:06
@brb-nv (Collaborator, Author) commented Apr 9, 2025

> @brb-nv can you update examples/gemma/README.md with gemma3 instructions. specifically, clearly show instructions to set up the correct attention window values.
>
> ideally, the attention window values should be part of the model definition so that users do not have to provide these values. we can add that in a follow-up PR.

Updated here, Anurag. 1ad80ea

@brb-nv brb-nv force-pushed the user/brb/gemma3-on-main branch from 0be1885 to f34b971 on April 9, 2025 21:02
@brb-nv (Collaborator, Author) commented Apr 9, 2025

/bot run

@tensorrt-cicd (Collaborator)

PR_Github #1646 [ run ] triggered by Bot

@tensorrt-cicd (Collaborator)

PR_Github #1646 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #1231 completed with status: 'FAILURE'

@brb-nv brb-nv force-pushed the user/brb/gemma3-on-main branch 2 times, most recently from 1d1fadd to 7ebad03 on April 10, 2025 01:27
@brb-nv (Collaborator, Author) commented Apr 10, 2025

/bot run

@tensorrt-cicd (Collaborator)

PR_Github #1670 [ run ] triggered by Bot

@tensorrt-cicd (Collaborator)

PR_Github #1670 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #1248 completed with status: 'SUCCESS'

@chzblych chzblych force-pushed the user/brb/gemma3-on-main branch from 7ebad03 to decb665 on April 10, 2025 04:23
@chzblych (Collaborator)

/bot reuse-pipeline

@chzblych chzblych enabled auto-merge (squash) April 10, 2025 04:23
@chzblych (Collaborator) left a comment:

Approving this MR for merging.

@tensorrt-cicd (Collaborator)

PR_Github #1699 [ reuse-pipeline ] triggered by Bot

@tensorrt-cicd (Collaborator)

PR_Github #1699 [ reuse-pipeline ] completed with state SUCCESS
Reusing PR_Github #1670 for commit decb665

@chzblych chzblych merged commit c59abae into NVIDIA:main Apr 10, 2025
2 checks passed
Superjomn pushed a commit to Superjomn/TensorRT-LLM that referenced this pull request Apr 11, 2025
5 participants