feat: add flash_attn 2 to bert #27478
Conversation
Thanks a lot for your PR! In principle this looks great!
Many architectures use BertAttention with `# Copied from`, so all of these architectures could benefit from FA-2 for free; however, you will need to apply `_supports_flash_attn_2 = True` on all of them. You need to:
1. run `make fix-copies`
2. on the modified architectures, add the flag above and copy-paste the BertFlashAttention class into all of them (with modified names); a sketch of this pattern follows below.
Would you be happy to address these changes? Otherwise happy to help you!
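For reference, here is a minimal sketch of what step 2 can look like on a downstream architecture. `Foo` is a placeholder model name, the class bodies are omitted, and this is only an illustration of the pattern, not the actual PR diff:

```python
# Sketch only: "Foo" stands in for any architecture whose attention is
# "# Copied from ...BertAttention" and therefore gets updated by step 1
# (`make fix-copies`). Step 2 is the per-architecture opt-in shown here.
import torch.nn as nn


class FooPreTrainedModel:
    """Stand-in for the architecture's real PreTrainedModel subclass."""

    # Opt the architecture into the Flash Attention 2 code path; this flag is
    # checked when FA-2 is requested at load time.
    _supports_flash_attn_2 = True


class FooFlashAttention2(nn.Module):
    # Copied from transformers.models.bert.modeling_bert.BertFlashAttention2 with Bert->Foo
    """Renamed copy of BertFlashAttention2; the body is omitted in this sketch."""
```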
Thanks a lot for your review and your suggestions @younesbelkada.
Perfect, thanks!
Thanks @younesbelkada, I did it.
Didn't have time to properly look into it, will do it ASAP!
Any updates on getting this PR merged?
Hello there! I'm working on integrating scaled_dot_product_attention into BERT in #28802, and there might be some merge conflicts with this change. Mostly, if my changes go through, we can get rid of most of the downstream dependencies from fix-copies. Let me know if you have any questions. Happy to discuss and/or chat about the best way forward if necessary.
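For illustration, here is a rough sketch of what an SDPA-based self-attention forward can look like; the class name and shapes are assumptions for the example, not the actual #28802 implementation:

```python
# Rough sketch of an SDPA-style self-attention forward (illustrative only).
import torch
import torch.nn as nn
import torch.nn.functional as F


class SdpaSelfAttentionSketch(nn.Module):
    def __init__(self, hidden_size: int = 768, num_heads: int = 12):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.query = nn.Linear(hidden_size, hidden_size)
        self.key = nn.Linear(hidden_size, hidden_size)
        self.value = nn.Linear(hidden_size, hidden_size)

    def _split_heads(self, x: torch.Tensor) -> torch.Tensor:
        # (batch, seq, hidden) -> (batch, heads, seq, head_dim)
        b, s, _ = x.shape
        return x.view(b, s, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, hidden_states: torch.Tensor, attention_mask=None) -> torch.Tensor:
        q = self._split_heads(self.query(hidden_states))
        k = self._split_heads(self.key(hidden_states))
        v = self._split_heads(self.value(hidden_states))
        # attention_mask, if given, must be broadcastable to
        # (batch, heads, seq, seq). PyTorch dispatches to a Flash Attention
        # kernel here when the hardware/dtype/mask combination allows it.
        attn = F.scaled_dot_product_attention(q, k, v, attn_mask=attention_mask)
        b, _, s, _ = attn.shape
        return attn.transpose(1, 2).reshape(b, s, -1)
```

The appeal is that only the softmax(QK^T)V computation changes, so the weights and the eager code path stay untouched.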
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
This would be a lifesaver for me, I hope merging this is prioritized! cc @younesbelkada
@loswald, you can see a quick estimate of the speedups in #28802. The PyTorch SDPA implementation uses FA2 under the hood (if your hardware supports it). The PR is ready, but we're just waiting on the HF team to merge it.
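As a quick sanity check that SDPA actually dispatches to a Flash Attention kernel on your hardware, something along these lines should work on recent PyTorch and transformers versions. Treat it as a sketch: `attn_implementation="sdpa"` only applies to BERT once #28802 is merged.

```python
# Sketch: run BERT with PyTorch SDPA and restrict it to the flash backend
# to verify hardware support (if flash isn't available, this will error out).
import torch
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModel.from_pretrained(
    "bert-base-uncased",
    torch_dtype=torch.float16,
    attn_implementation="sdpa",  # effective for BERT once SDPA support is merged
).to("cuda")

inputs = tokenizer("Hello world", return_tensors="pt").to("cuda")
# Newer torch versions expose torch.nn.attention.sdpa_kernel instead.
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    with torch.no_grad():
        out = model(**inputs)
print(out.last_hidden_state.shape)
```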
Feat: Add flash attention option for BERT
Usage:
```python
model = BertModel.from_pretrained('bert-base-uncased', torch_dtype=torch.bfloat16, use_flash_attention_2=True)
```
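A slightly fuller, self-contained version of the snippet above, as a sketch; FA-2 requires the flash-attn package, a supported GPU, and fp16/bf16 weights:

```python
# Self-contained usage sketch for BERT with Flash Attention 2.
import torch
from transformers import BertModel, BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertModel.from_pretrained(
    "bert-base-uncased",
    torch_dtype=torch.bfloat16,
    use_flash_attention_2=True,
).to("cuda")

inputs = tokenizer("Flash attention example", return_tensors="pt").to("cuda")
with torch.no_grad():
    outputs = model(**inputs)
print(outputs.last_hidden_state.shape)  # (1, seq_len, 768)
```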
@ArthurZucker and @younesbelkada