[ESM] Add support for sdpa. #34954
Conversation
Thank you for this! Rather than skipping the SDPA test, though, can you write even a simple test that uses the SDPA path? It's okay if it can't compare hidden states deeply because of issues in the ESM model, but if it could compare that output logits are similar that'd give us a lot more confidence in the SDPA code!
Thanks for your reply, I will add relevant test cases soon.
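A hedged sketch of the kind of test being asked for: comparing eager vs. SDPA logits on a small public ESM-2 checkpoint. The checkpoint name and tolerances are illustrative choices, not the test that was eventually committed.

```python
# Illustrative only: compare eager vs. SDPA logits on a small ESM-2 checkpoint.
import torch
from transformers import AutoTokenizer, EsmForMaskedLM

checkpoint = "facebook/esm2_t6_8M_UR50D"  # small public checkpoint, chosen for speed
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
inputs = tokenizer("MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ", return_tensors="pt")

model_eager = EsmForMaskedLM.from_pretrained(checkpoint, attn_implementation="eager").eval()
model_sdpa = EsmForMaskedLM.from_pretrained(checkpoint, attn_implementation="sdpa").eval()

with torch.no_grad():
    logits_eager = model_eager(**inputs).logits
    logits_sdpa = model_sdpa(**inputs).logits

# Tolerances are a judgment call; SDPA and eager differ slightly in float accumulation order.
assert torch.allclose(logits_eager, logits_sdpa, atol=1e-4, rtol=1e-4)
```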
Force-pushed from 8f7773d to 996880a
@Rocketknight1 Hello, the SDPA inference tests for ESMFold have been added. Could you please review them?
Overall this looks like a good SDPA addition to me! I'll also set up slow tests in a sec.
Hi @wzf03, I ran the full test suite for ESM and I'm seeing one or two test failures. Can you see if you can reproduce those locally? They may just be flaky tests, but it might also be caused by changes in this PR.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Hello @Rocketknight1, I found the test failures were due to a device mismatch; I will report this.
Hello @Rocketknight1, I made a quick fix based on other models' tests; the test cases should work normally now.
Yes, looks good to me now! cc @ArthurZucker @LysandreJik for core maintainer review
@ArthurZucker @LysandreJik Hello! Can you please help review this PR?
Hey, super sorry for the delay; I waited a bit because #35235 changes the interface! Do you mind updating this PR? Hope it's not too much of a burden! 😿
Sure, I will do it soon.
Not sure if this is still active, but I have a similar PR in #38023 to add Flash Attention 2 to ESM.
The FA2 PR was merged; TBH we'd rather have a small refactor to use the new attention interface.
@ArthurZucker @Rocketknight1 Sorry for the late update. I have merged the SDPA support into the new codebase; can you help review this?
thanks for updating!
class EsmSdpaSelfAttention(EsmSelfAttention):
    def __init__(self, config, position_embedding_type=None):
        super().__init__(config, position_embedding_type)
        self.attention_dropout_prob = config.attention_probs_dropout_prob
        self.require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0")

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        if self.position_embedding_type not in ["absolute", "rotary"] or output_attentions or head_mask is not None:
            # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once implemented.
            logger.warning_once(
                "EsmSdpaSelfAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
                "non-absolute or non-rotary `position_embedding_type` or `output_attentions=True` or `head_mask`. "
                "Falling back to the manual attention implementation, but specifying the manual implementation will "
                "be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument "
                '`attn_implementation="eager"` when loading the model.'
            )
            return super().forward(
                hidden_states,
                attention_mask,
                head_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                past_key_value,
                output_attentions,
            )

        bsz, tgt_len, _ = hidden_states.size()

        query_layer = self.transpose_for_scores(self.query(hidden_states))

        # If this is instantiated as a cross-attention module, the keys and values come from an encoder; the attention
        # mask needs to be such that the encoder's padding tokens are not attended to.
        is_cross_attention = encoder_hidden_states is not None

        current_states = encoder_hidden_states if is_cross_attention else hidden_states
        attention_mask = encoder_attention_mask if is_cross_attention else attention_mask

        # Check `seq_length` of `past_key_value` == `len(current_states)` to support prefix tuning
        if is_cross_attention and past_key_value and past_key_value[0].shape[2] == current_states.shape[1]:
            key_layer, value_layer = past_key_value
        else:
            key_layer = self.transpose_for_scores(self.key(current_states))
            value_layer = self.transpose_for_scores(self.value(current_states))
            if past_key_value is not None and not is_cross_attention:
                key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
                value_layer = torch.cat([past_key_value[1], value_layer], dim=2)

        # Scale the query for rotary embeddings
        query_layer = query_layer * self.attention_head_size**-0.5

        if self.is_decoder:
            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
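The diff excerpt above is cut off just after the decoder caching comment, before the SDPA call itself. For context, here is a rough, hedged sketch of how a BERT-style SDPA forward typically concludes; attribute names such as `self.rotary_embeddings` and `self.all_head_size` follow the existing ESM attention code, and this is not necessarily the exact code in this PR.

```python
        # Sketch only (not the exact PR code) of how the forward typically concludes.
        if self.is_decoder:
            # Cache key/value states for reuse in subsequent decoding steps.
            past_key_value = (key_layer, value_layer)

        # Apply rotary position embeddings to queries/keys when configured, as in the eager ESM attention.
        if self.position_embedding_type == "rotary":
            query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer)

        # SDPA on torch < 2.2 with a custom mask can require contiguous inputs on CUDA.
        if self.require_contiguous_qkv and query_layer.device.type == "cuda" and attention_mask is not None:
            query_layer = query_layer.contiguous()
            key_layer = key_layer.contiguous()
            value_layer = value_layer.contiguous()

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_layer,
            key_layer,
            value_layer,
            attn_mask=attention_mask,
            dropout_p=self.attention_dropout_prob if self.training else 0.0,
            scale=1.0,  # query was already scaled above, so disable SDPA's own scaling (needs torch >= 2.1)
        )
        attn_output = attn_output.transpose(1, 2).reshape(bsz, tgt_len, self.all_head_size)

        outputs = (attn_output,)
        if self.is_decoder:
            outputs = outputs + (past_key_value,)
        return outputs
```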
I think you are still missing the spot: we don't need 3 different classes anymore, see https://github.com/huggingface/transformers/blob/tp-cb/src/transformers/models/llama/modeling_llama.py#L249-L249
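For readers unfamiliar with the pattern being referenced: the linked llama code keeps a single attention class and dispatches to the eager/SDPA/FlashAttention backends through `ALL_ATTENTION_FUNCTIONS`, selected by `config._attn_implementation`. Below is a rough, hedged sketch of that shape applied to an ESM-style encoder attention; the class name, projection layout, and helper are illustrative, not the code ultimately merged.

```python
# Rough sketch of the single-attention-class pattern from modeling_llama.py,
# transplanted onto a BERT/ESM-style attention. Names here are illustrative.
import torch
from torch import nn

from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS


def eager_attention_forward(module, query, key, value, attention_mask, scaling, dropout=0.0, **kwargs):
    # Plain matmul-softmax attention, used when config._attn_implementation == "eager".
    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask[:, :, :, : key.shape[-2]]
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value).transpose(1, 2).contiguous()
    return attn_output, attn_weights


class EsmStyleSelfAttention(nn.Module):
    # One class instead of per-backend subclasses: the implementation is looked up
    # from config._attn_implementation at runtime.
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.num_heads = config.num_attention_heads
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.query = nn.Linear(config.hidden_size, config.hidden_size)
        self.key = nn.Linear(config.hidden_size, config.hidden_size)
        self.value = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout_prob = config.attention_probs_dropout_prob
        self.is_causal = False  # encoder-style attention; keep SDPA/FA2 from assuming causal masking

    def forward(self, hidden_states, attention_mask=None, **kwargs):
        bsz, seq_len, _ = hidden_states.size()
        shape = (bsz, seq_len, self.num_heads, self.head_dim)
        query = self.query(hidden_states).view(shape).transpose(1, 2)
        key = self.key(hidden_states).view(shape).transpose(1, 2)
        value = self.value(hidden_states).view(shape).transpose(1, 2)

        attention_interface = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, _ = attention_interface(
            self,
            query,
            key,
            value,
            attention_mask,
            dropout=0.0 if not self.training else self.dropout_prob,
            scaling=self.head_dim**-0.5,
        )
        return attn_output.reshape(bsz, seq_len, -1)
```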
What does this PR do?
Add support for SDPA (scaled dot product attention) for ESM. More context in #28802 (this PR mainly reuses the code from that PR, since ESM is a BERT-based model) and #28005.
This is my first time contributing to this project, so please point out any mistakes.
This PR also reverts a change from #29329, since the dtype mismatch issue with bitsandbytes is actually caused by the rotary embeddings.
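For end users, the practical effect of a change like this is being able to pick the SDPA backend when loading an ESM model. A small illustrative example; the checkpoint and dtype are placeholders, not taken from this PR.

```python
# Illustrative usage once SDPA support is available for ESM.
import torch
from transformers import EsmModel

model = EsmModel.from_pretrained(
    "facebook/esm2_t12_35M_UR50D",  # example checkpoint
    attn_implementation="sdpa",     # falls back to eager (with a warning) where SDPA is unsupported
    torch_dtype=torch.float16,
)
```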
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@ArthurZucker