
Add PT compileable support for flash_attn_with_kvcache #1592


Open · wants to merge 1 commit into main

Conversation

@jataylo commented Apr 14, 2025

Continues #1139 adding custom op for flash_attn_with_kvcache.

On a transformers model this improves performance by >2x by avoiding graph breaks. One gotcha: with this implementation, PyTorch 2.6 throws an error in user code when reshaping the FA output:

```
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
RuntimeError: <weakref at 0x7f10e00494e0; to 'torch.storage.UntypedStorage' at 0x7f10e0049400>
```

This is not an issue on PyTorch 2.7, so I worked around it conditionally: the op returns a clone of the output tensors only on PT versions earlier than 2.7, and only when compile is being used.

@jataylo (Author) commented Apr 16, 2025

@tridao Alternatively, if preferred, instead of conditionalising the clone for < PT 2.7 we could just disable compileable support for this op below 2.7; the additional clone could cause regressions and increase memory usage.

@tridao (Member) commented Apr 22, 2025

We will drop support for PyTorch < 2.4, so you can simplify the code.
I'll need to think more about the clone. Does it slow things down when running in eager?
