
[ESM] Add support for sdpa. #34954


Open · wants to merge 4 commits into base: main

Conversation

@wzf03 commented on Nov 27, 2024

What does this PR do?

Add support for SDPA (scaled dot product attention) for ESM. More context in #28802 (this PR mainly reuses the code from that PR, as ESM is a BERT-based model) and #28005.

This is my first time contributing to this project, so please point out any mistakes.

This PR also reverts a change from #29329, as the dtype-mismatch issue for bitsandbytes is actually caused by the rotary embedding.
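For readers who just want to try the feature, a minimal sketch of how SDPA would be selected once this PR lands (not part of the PR's diff; it assumes the standard `attn_implementation` argument of `from_pretrained` and the public `facebook/esm2_t6_8M_UR50D` checkpoint):

# Illustrative only, not this PR's code: enable SDPA via the standard
# `attn_implementation` argument when loading an ESM model.
import torch
from transformers import AutoModel, AutoTokenizer

model_name = "facebook/esm2_t6_8M_UR50D"  # assumed public ESM-2 checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name, attn_implementation="sdpa")

inputs = tokenizer("MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)
print(outputs.last_hidden_state.shape)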

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@ArthurZucker

@Rocketknight1 (Member) commented

Thank you for this! Rather than skipping the SDPA test, though, can you write even a simple test that uses the SDPA path? It's okay if it can't compare hidden states deeply because of issues in the ESM model, but if it could check that the output logits are similar, that would give us a lot more confidence in the SDPA code!

@wzf03 (Author) commented on Nov 28, 2024

> Thank you for this! Rather than skipping the SDPA test, though, can you write even a simple test that uses the SDPA path? It's okay if it can't compare hidden states deeply because of issues in the ESM model, but if it could check that the output logits are similar, that would give us a lot more confidence in the SDPA code!

Thanks for your reply, I will add relevant test cases soon.
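A minimal sketch of the kind of test being requested, comparing eager and SDPA output logits (the checkpoint and tolerances here are assumptions; the test actually added to the PR may differ):

# Illustrative test sketch, not the PR's test: the eager and SDPA attention
# paths should produce logits that agree within numerical tolerance.
import torch
from transformers import AutoTokenizer, EsmForMaskedLM

checkpoint = "facebook/esm2_t6_8M_UR50D"  # assumed small ESM-2 checkpoint
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
inputs = tokenizer("MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ", return_tensors="pt")

model_eager = EsmForMaskedLM.from_pretrained(checkpoint, attn_implementation="eager").eval()
model_sdpa = EsmForMaskedLM.from_pretrained(checkpoint, attn_implementation="sdpa").eval()

with torch.no_grad():
    logits_eager = model_eager(**inputs).logits
    logits_sdpa = model_sdpa(**inputs).logits

torch.testing.assert_close(logits_eager, logits_sdpa, atol=1e-4, rtol=1e-4)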

@wzf03 (Author) commented on Nov 28, 2024

> Thank you for this! Rather than skipping the SDPA test, though, can you write even a simple test that uses the SDPA path? It's okay if it can't compare hidden states deeply because of issues in the ESM model, but if it could check that the output logits are similar, that would give us a lot more confidence in the SDPA code!

@Rocketknight1 Hello, the SDPA inference tests for ESMFold have been added. Could you please review them?

@Rocketknight1 (Member) left a comment

Overall this looks like a good SDPA addition to me! I'll also set up slow tests in a sec.

@Rocketknight1 (Member) commented

Hi @wzf03, I ran the full test suite for ESM and I'm seeing one or two test failures. Can you see if you can reproduce those locally? They may just be flaky tests, but it might also be caused by changes in this PR.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@wzf03 (Author) commented on Nov 30, 2024

> Hi @wzf03, I ran the full test suite for ESM and I'm seeing one or two test failures. Can you see if you can reproduce those locally? They may just be flaky tests, but it might also be caused by changes in this PR.

Hello @Rocketknight1, I found the test failures were due to a device mismatch between the input_ids (on CPU) and the model (on CUDA) under the bitsandbytes setting. It can be reproduced locally with the current latest main branch of accelerate (commit 29be4788629b772a3b722076e433b5b3b5c85da3), but in my original test environment with accelerate==1.1.1 everything works well.

I will report this to accelerate later.

@wzf03 (Author) commented on Dec 2, 2024

Hello @Rocketknight1, I made a quick fix based on other models' tests; the test cases should work normally now.
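For context, the usual pattern for this kind of fix in transformers integration tests is to move the tokenized inputs onto the model's device before calling it. A hedged sketch of that pattern (checkpoint and names are assumptions; the actual change in this PR may differ):

# Sketch of the common device-mismatch fix, not the PR's actual change:
# put the inputs on the same device as the (possibly quantized) model.
import torch
from transformers import AutoTokenizer, EsmForMaskedLM

checkpoint = "facebook/esm2_t6_8M_UR50D"  # assumed checkpoint
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = EsmForMaskedLM.from_pretrained(checkpoint).to(device)

inputs = tokenizer("MKTAYIAK", return_tensors="pt")
inputs = {k: v.to(model.device) for k, v in inputs.items()}  # avoid CPU/CUDA mismatch

with torch.no_grad():
    logits = model(**inputs).logits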

@Rocketknight1 (Member) commented

Yes, looks good to me now! cc @ArthurZucker @LysandreJik for core maintainer review

@wzf03 (Author) commented on Dec 9, 2024

@ArthurZucker @LysandreJik Hello! Could you please help review this PR?

@ArthurZucker (Collaborator) left a comment

Hey, super sorry for the delay; I waited a bit because #35235 changes the interface! Do you mind updating this PR? Hope it's not too much of a burden! 😿

@wzf03 (Author) commented on Dec 21, 2024

> Hey, super sorry for the delay; I waited a bit because #35235 changes the interface! Do you mind updating this PR? Hope it's not too much of a burden! 😿

Sure, I will do it soon.

@pstjohn (Contributor) commented on May 13, 2025

Not sure if this is still active, but I have a similar PR in #38023 to add flash attention 2 to ESM

@ArthurZucker (Collaborator) commented

The FA2 PR was merged; TBH we'd rather have a small refactor to use the new ATTENTION_INTERFACE!
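For readers unfamiliar with the refactor being requested: newer transformers models use a single attention class that dispatches to the configured backend (eager, SDPA, FA2) through ALL_ATTENTION_FUNCTIONS, rather than one class per backend. A rough sketch adapted from the Llama-style pattern follows; class and helper names here are simplified assumptions, not this PR's code.

# Rough sketch of the single-class attention dispatch pattern (assumed names,
# adapted from the Llama-style modules in recent transformers versions).
from typing import Callable, Optional

import torch
from torch import nn
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS


def eager_attention_forward(module, query, key, value, attention_mask, scaling, dropout=0.0, **kwargs):
    # Plain softmax attention, used when no optimized backend is requested.
    # query/key/value are shaped (batch, num_heads, seq_len, head_dim).
    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask[:, :, :, : key.shape[-2]]
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value).transpose(1, 2).contiguous()
    return attn_output, attn_weights


class EsmAttentionSketch(nn.Module):
    # Hypothetical single attention class: one implementation for all backends.
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.scaling = self.head_dim**-0.5
        self.dropout = config.attention_probs_dropout_prob

    def forward(self, query, key, value, attention_mask: Optional[torch.Tensor] = None):
        # Route to eager, sdpa, flash_attention_2, ... based on the config.
        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
        return attention_interface(
            self, query, key, value, attention_mask,
            dropout=0.0 if not self.training else self.dropout,
            scaling=self.scaling,
        )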

@wzf03 force-pushed the esm-sdpa-support branch from f5d9ecc to 7400d4b on May 20, 2025, 13:33
@wzf03 (Author) commented on May 20, 2025

@ArthurZucker @Rocketknight1 Sorry for the late update. I have merged the SDPA support into the new codebase; could you help review it?

@ArthurZucker (Collaborator) left a comment

thanks for updating!

Comment on lines +408 to +470
class EsmSdpaSelfAttention(EsmSelfAttention):
    def __init__(self, config, position_embedding_type=None):
        super().__init__(config, position_embedding_type)
        self.attention_dropout_prob = config.attention_probs_dropout_prob
        self.require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0")

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        if self.position_embedding_type not in ["absolute", "rotary"] or output_attentions or head_mask is not None:
            # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once implemented.
            logger.warning_once(
                "EsmSdpaSelfAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
                "non-absolute or non-rotary `position_embedding_type` or `output_attentions=True` or `head_mask`. "
                "Falling back to the manual attention implementation, but specifying the manual implementation will "
                "be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument "
                '`attn_implementation="eager"` when loading the model.'
            )
            return super().forward(
                hidden_states,
                attention_mask,
                head_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                past_key_value,
                output_attentions,
            )

        bsz, tgt_len, _ = hidden_states.size()

        query_layer = self.transpose_for_scores(self.query(hidden_states))

        # If this is instantiated as a cross-attention module, the keys and values come from an encoder; the attention
        # mask needs to be such that the encoder's padding tokens are not attended to.
        is_cross_attention = encoder_hidden_states is not None

        current_states = encoder_hidden_states if is_cross_attention else hidden_states
        attention_mask = encoder_attention_mask if is_cross_attention else attention_mask

        # Check `seq_length` of `past_key_value` == `len(current_states)` to support prefix tuning
        if is_cross_attention and past_key_value and past_key_value[0].shape[2] == current_states.shape[1]:
            key_layer, value_layer = past_key_value
        else:
            key_layer = self.transpose_for_scores(self.key(current_states))
            value_layer = self.transpose_for_scores(self.value(current_states))
            if past_key_value is not None and not is_cross_attention:
                key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
                value_layer = torch.cat([past_key_value[1], value_layer], dim=2)

        # Scale the query for rotary embeddings
        query_layer = query_layer * self.attention_head_size**-0.5

        if self.is_decoder:
            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
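The diff view truncates here, so the actual SDPA call in this class is not shown. For readers unfamiliar with the PyTorch primitive involved, a standalone illustration of torch.nn.functional.scaled_dot_product_attention follows (not the PR's code; shapes, mask, and dropout value are arbitrary examples):

# Standalone illustration of the PyTorch SDPA primitive this class wraps.
import torch
import torch.nn.functional as F

bsz, num_heads, seq_len, head_dim = 2, 8, 16, 32
query = torch.randn(bsz, num_heads, seq_len, head_dim)
key = torch.randn(bsz, num_heads, seq_len, head_dim)
value = torch.randn(bsz, num_heads, seq_len, head_dim)

# Additive float mask broadcastable to (bsz, num_heads, seq_len, seq_len):
# 0 keeps a position, a large negative value masks it out.
attn_mask = torch.zeros(bsz, 1, seq_len, seq_len)

out = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0)
print(out.shape)  # torch.Size([2, 8, 16, 32])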
A Collaborator commented on the diff:

I think you are still missing the spot: we don't need 3 different classes anymore, see https://github.com/huggingface/transformers/blob/tp-cb/src/transformers/models/llama/modeling_llama.py#L249-L249
