Add PT compilable support for flash_attn_with_kvcache #1592
Continues #1139, adding a custom op for flash_attn_with_kvcache.
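For context, a minimal sketch of what such a custom-op registration can look like, assuming PyTorch's `torch.library.custom_op` API (available since 2.4); the op name and wrapper below are illustrative, not the PR's actual code:

```python
import torch
from flash_attn import flash_attn_with_kvcache

# Hypothetical wrapper registering flash_attn_with_kvcache as a custom op so
# torch.compile can trace through it without a graph break. k_cache/v_cache
# are declared as mutated because the kernel writes new k/v into the cache.
@torch.library.custom_op("flash_attn::_fwd_kvcache", mutates_args={"k_cache", "v_cache"})
def _fwd_kvcache(
    q: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cache_seqlens: torch.Tensor,
) -> torch.Tensor:
    return flash_attn_with_kvcache(
        q, k_cache, v_cache, k=k, v=v, cache_seqlens=cache_seqlens
    )

@_fwd_kvcache.register_fake
def _fwd_kvcache_fake(q, k_cache, v_cache, k, v, cache_seqlens):
    # The output matches q in shape and dtype, which is all the compiler
    # needs to propagate shapes without running the real kernel.
    return torch.empty_like(q)
```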
On a transformers model this improves performance by more than 2x by avoiding graph breaks. There is one gotcha: with this implementation, PyTorch 2.6 throws an error in user code when reshaping the FA output.
This is not an issue on PyTorch 2.7, so I had to add a version check to work around it: the output tensors are cloned only when compile is being used on PT versions earlier than 2.7.
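A sketch of the conditional clone described above; the helper name and version check are illustrative, not the exact code in the PR:

```python
import torch

# On PyTorch < 2.7, reshaping the raw custom-op output under torch.compile
# raises an error, so return a clone only in that case; eager mode and newer
# PyTorch versions skip the extra copy.
_PT_BEFORE_2_7 = torch.__version__ < "2.7"

def _maybe_clone(out: torch.Tensor) -> torch.Tensor:
    if _PT_BEFORE_2_7 and torch.compiler.is_compiling():
        return out.clone()
    return out
```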