
gemma 3 architecture #2880

Open
Alireza3242 opened this issue Mar 12, 2025 · 6 comments

Comments

@Alireza3242

Can you add the Gemma 3 architecture?

@zhaocc1106

+1

@artur-pf

Would be epic; ollama and llama.cpp have already implemented it.

@zhaocc1106

zhaocc1106 commented Mar 14, 2025

I found that the following attention mask type is not supported yet:
https://github.com/huggingface/transformers/blob/42ebb6c23e61119f769d7c7c067d5b4ae10e4a7f/src/transformers/models/gemma3/modeling_gemma3.py#L1147

# Apply bidirectional mask on images if token type ids are provided
if token_type_ids is not None and sequence_length != 1:
    token_type_mask = token_type_ids.unsqueeze(1) == token_type_ids.unsqueeze(2)
    token_type_mask[token_type_ids == 0] = False  # if text token do not change anything
    token_type_mask = token_type_mask.unsqueeze(1).to(causal_mask.device, dtype=torch.bool)
    causal_mask = causal_mask.clone()
    causal_mask[:, :, :, :sequence_length] = causal_mask[:, :, :, :sequence_length].masked_fill(
        token_type_mask, 0.0
    )

Could I set the attention_mask of the prefill stage with the executor API? Thanks.
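
For concreteness, here is a minimal standalone sketch (the toy shapes and token_type_ids are illustrative assumptions, not the executor API) of what the snippet above does to a small prefill mask, where 0 marks text tokens and 1 marks image tokens:

import torch

sequence_length = 5
min_dtype = torch.finfo(torch.float32).min

# Toy causal mask of shape [batch, 1, seq, seq]: 0 where attention is allowed, -inf above the diagonal.
causal_mask = torch.full((1, 1, sequence_length, sequence_length), min_dtype)
causal_mask = torch.triu(causal_mask, diagonal=1)

# One text token, three image tokens, one text token.
token_type_ids = torch.tensor([[0, 1, 1, 1, 0]])

token_type_mask = token_type_ids.unsqueeze(1) == token_type_ids.unsqueeze(2)
token_type_mask[token_type_ids == 0] = False  # text tokens keep the causal pattern
token_type_mask = token_type_mask.unsqueeze(1).to(causal_mask.device, dtype=torch.bool)
causal_mask = causal_mask.clone()
causal_mask[:, :, :, :sequence_length] = causal_mask[:, :, :, :sequence_length].masked_fill(
    token_type_mask, 0.0
)
print(causal_mask[0, 0])  # image rows attend bidirectionally within the image block

Rows for image tokens end up attending bidirectionally within the image block, while text rows keep the usual causal pattern.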

@zhaocc1106

zhaocc1106 commented Mar 15, 2025

I'm blocked here. I have added support for the Gemma 3 text LLM by following the Gemma 2 implementation, and it works correctly when the input is text only, without image tokens. But when the input contains an image (injected via p-tuning embeddings), the output is wrong. I finally found that the attention mask differs between text tokens and image tokens in the prefill phase. For example:

Text tokens only (a pure causal mask):

         token1  token2  token3  token4  token5
token1   0       -inf    -inf    -inf    -inf
token2   0       0       -inf    -inf    -inf
token3   0       0       0       -inf    -inf
token4   0       0       0       0       -inf
token5   0       0       0       0       0

With image tokens (not a pure causal mask):

             txt_token1  img_token2  img_token3  img_token4  txt_token5
txt_token1   0           -inf        -inf        -inf        -inf
img_token2   0           0           0           0           -inf
img_token3   0           0           0           0           -inf
img_token4   0           0           0           0           -inf
txt_token5   0           0           0           0           0

Is there a good way to support this? Thanks very much!
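
For illustration only (plain PyTorch, not the TensorRT-LLM executor API), a small sketch showing that the mask above is just an additive bias built from token_type_ids (0 = text, 1 = image), so any attention path that accepts an arbitrary custom mask for the prefill step could consume it; build_hybrid_mask and the shapes below are assumptions for the example:

import torch
import torch.nn.functional as F

def build_hybrid_mask(token_type_ids: torch.Tensor) -> torch.Tensor:
    """token_type_ids: [seq], 0 = text, 1 = image. Returns an additive mask [seq, seq]."""
    seq = token_type_ids.shape[0]
    neg_inf = torch.finfo(torch.float32).min
    mask = torch.triu(torch.full((seq, seq), neg_inf), diagonal=1)  # causal base
    same_image = (token_type_ids[:, None] == 1) & (token_type_ids[None, :] == 1)
    return mask.masked_fill(same_image, 0.0)  # image tokens attend to each other bidirectionally

token_type_ids = torch.tensor([0, 1, 1, 1, 0])  # txt, img, img, img, txt as in the table above
mask = build_hybrid_mask(token_type_ids)
print(mask)

# Any attention call that takes a custom additive mask can use it for the prefill:
q = k = v = torch.randn(1, 1, 5, 8)  # [batch, heads, seq, head_dim]
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)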

@zhaocc1106

zhaocc1106 commented Mar 17, 2025

I also found that Gemma 3 uses a sliding-window causal attention mask, not a plain causal one, which results in incorrect output for long inputs. Could we support a sliding-window causal mask? If it is already supported, is there a description of it? Thanks very much.

Additionally, I found that the current gpt attention does not support the AttentionMaskType.sliding_window_causal type:

assert self.attention_mask_type in [

sliding_window_causal is an important feature for some new LLMs.
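
For reference, a minimal PyTorch sketch of the sliding-window causal pattern (the window size and sequence length here are illustrative, not Gemma 3's actual configuration): each query may attend only to itself and the previous window - 1 tokens instead of the full causal prefix.

import torch

def sliding_window_causal_mask(seq_len: int, window: int) -> torch.Tensor:
    """Additive mask [seq_len, seq_len]: 0 where attention is allowed, -inf otherwise."""
    neg_inf = torch.finfo(torch.float32).min
    i = torch.arange(seq_len)[:, None]  # query positions
    j = torch.arange(seq_len)[None, :]  # key positions
    allowed = (j <= i) & (j > i - window)  # causal, and within the last `window` tokens
    return torch.full((seq_len, seq_len), neg_inf).masked_fill(allowed, 0.0)

print(sliding_window_causal_mask(seq_len=6, window=3))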

@zhaocc1106

I have tried to support the Gemma 3 text LLM (https://github.com/NetEase-Media/grps_trtllm/tree/master/tools/gemma3/tensorrt_llm_mod). But there is one issue: KV cache reuse is not supported, see #2912.
Additionally, it cannot process image tokens because the image-token attention mask is not supported, as noted in #2880 (comment).
