feat: add flash_attn 2 to bert #27478

Closed · wants to merge 3 commits

Conversation

@chiennv2000 commented Nov 14, 2023

Feat: Add flash attention option for BERT
Usage:
model = BertModel.from_pretrained('bert-base-uncased', torch_dtype=torch.bfloat16, use_flash_attention_2=True)
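
For context, a slightly fuller sketch of that usage (assuming flash-attn 2 is installed and a CUDA GPU with bf16 support is available; the tokenizer call and forward pass are illustrative additions, not part of the PR):

```python
# Fuller sketch of the usage above. Assumes flash-attn 2 is installed and a
# CUDA GPU with bf16 support; tokenizer and forward pass are illustrative.
import torch
from transformers import AutoTokenizer, BertModel

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = BertModel.from_pretrained(
    "bert-base-uncased",
    torch_dtype=torch.bfloat16,
    use_flash_attention_2=True,  # flag proposed in this PR
).to("cuda")

inputs = tokenizer("Flash attention for BERT", return_tensors="pt").to("cuda")
with torch.no_grad():
    outputs = model(**inputs)
print(outputs.last_hidden_state.shape)  # torch.Size([1, seq_len, 768])
```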

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • Did you write any new necessary tests?

@ArthurZucker and @younesbelkada

@younesbelkada (Contributor) left a comment


Thanks a lot for your PR! In principle this looks great!
Many architectures use BertAttention with # Copied from, so all of those architectures could benefit from FA-2 for free; however, you will need to set _supports_flash_attn_2 = True on each of them. You need to:
1- run make fix-copies
2- on the modified architectures, add the flag above and copy-paste the BertFlashAttention class into each of them (with modified names; see the sketch after this comment).
Would you be happy to address these changes? Otherwise happy to help you!
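
For readers unfamiliar with the flag mentioned in step 2, here is a small runnable check (the model classes are chosen purely for illustration, assuming a recent transformers install) of which BERT-family classes advertise FA-2 support through the _supports_flash_attn_2 class attribute:

```python
# Inspect the opt-in flag that step 2 asks to be set on each architecture that
# copies BertAttention. The model classes listed here are illustrative only.
from transformers import BertModel, RobertaModel, ElectraModel

for cls in (BertModel, RobertaModel, ElectraModel):
    supported = getattr(cls, "_supports_flash_attn_2", False)
    print(f"{cls.__name__}: _supports_flash_attn_2 = {supported}")
```

Setting that attribute to True on an architecture is what lets from_pretrained(..., use_flash_attention_2=True) proceed for that model rather than erroring out.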

@chiennv2000 (Author)

Thanks a lot for your review and your suggestions @younesbelkada.
But I'm not really familiar with the make fix-copies command. Can you guide me on how to do that?

@chiennv2000 (Author)

I appreciate your feedback, and I'm happy to receive your assistance in implementing these changes.
If you could help me with the other architectures, that would be fantastic. Additionally, I'm open to collaborating on extending this to the RoBERTa and XLM-R models. @younesbelkada

@younesbelkada (Contributor)

Perfect, thanks!
As a first step, can you simply run make fix-copies and push the changes here? Then we'll take it over from there!

@chiennv2000 (Author)

Thanks @younesbelkada, I did it.

@huggingface deleted a comment from github-actions bot Dec 14, 2023
@ArthurZucker (Collaborator)

cc @younesbelkada

@huggingface deleted a comment from github-actions bot Jan 8, 2024
@younesbelkada (Contributor)

Didn't have time to properly look into it, will do it ASAP!

@kevinhu (Contributor) commented Jan 31, 2024

Any updates on getting this PR merged?

@hackyon (Contributor) commented Feb 7, 2024

Hello there!

I'm working on integrating scaled_dot_product_attention into BERT in #28802, and there might be some merge conflicts with this change. If my changes go through, we should be able to get rid of most of the downstream dependencies from fix-copies.

Let me know if you have any questions. Happy to discuss and/or chat on the best way forward if necessary.
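
For reference, the SDPA path mentioned here is exposed through the attn_implementation argument in later transformers releases. A minimal sketch (not part of this PR; requires PyTorch >= 2.0 and a transformers version that ships the BERT SDPA integration):

```python
# Sketch of loading BERT with PyTorch's scaled_dot_product_attention, the
# path proposed in #28802. Requires PyTorch >= 2.0 and a transformers release
# that includes the BERT SDPA integration.
import torch
from transformers import BertModel

model = BertModel.from_pretrained(
    "bert-base-uncased",
    torch_dtype=torch.float16,
    attn_implementation="sdpa",
)
```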

github-actions bot commented Mar 3, 2024

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@loswald commented Apr 21, 2024

This would be a lifesaver for me, I hope merging this is prioritized! cc @younesbelkada
@chiennv2000 what kind of speedups are you observing with this?

@hackyon (Contributor) commented Apr 22, 2024

> This would be a lifesaver for me, I hope merging this is prioritized! cc @younesbelkada @chiennv2000 what kind of speedups are you observing with this?

@loswald - you can see a quick estimate of the speedups in #28802. The PyTorch SDPA implementation uses FA2 under the hood (if your hardware supports it). The PR is ready, but we're just waiting on the HF team to merge it.
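
To make "FA2 under the hood" concrete, here is a rough check that forces SDPA onto the FlashAttention backend (assumes PyTorch 2.0+ and an fp16-capable CUDA GPU; the context manager shown is deprecated in newer PyTorch releases in favor of torch.nn.attention.sdpa_kernel):

```python
# Rough check that PyTorch's SDPA can dispatch to the FlashAttention kernel on
# this GPU: restrict the allowed backends to flash only and run one call.
# Shapes/dtypes are arbitrary but flash-compatible (fp16, head_dim <= 128).
import torch
import torch.nn.functional as F

q, k, v = (torch.randn(1, 12, 512, 64, device="cuda", dtype=torch.float16)
           for _ in range(3))

with torch.backends.cuda.sdp_kernel(
    enable_flash=True, enable_math=False, enable_mem_efficient=False
):
    out = F.scaled_dot_product_attention(q, k, v)  # raises if flash is unavailable
print(out.shape)  # torch.Size([1, 12, 512, 64])
```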
