
Add PT compileable support for flash_attn_with_kvcache #1592

Open · wants to merge 1 commit into base: main
162 changes: 139 additions & 23 deletions flash_attn/flash_attn_interface.py
@@ -1,11 +1,14 @@
# Copyright (c) 2023, Tri Dao.

from typing import Optional, Sequence, Tuple, Union
from functools import lru_cache

import torch
import torch.nn as nn
import os

import torch._dynamo as dynamo

# isort: off
# We need to import the CUDA kernels after importing torch
USE_TRITON_ROCM = os.getenv("FLASH_ATTENTION_TRITON_AMD_ENABLE", "FALSE") == "TRUE"
@@ -14,8 +17,25 @@
else:
import flash_attn_2_cuda as flash_attn_gpu


# Register PyTorch custom ops for PT 2.4 and later.
USE_PT_COMPILE_OPS = (torch.__version__ >= "2.4.0")


# Clone outputs under dynamo to avoid a weak-reference error observed when reshaping
# output tensors in the kvcache implementation; this is not an issue in PT 2.7 and later.
USE_CLONE_FOR_DYNAMO = (torch.__version__ < "2.7.0")


# isort: on

@lru_cache(maxsize=None)
def _is_torch_compiling():
try:
return dynamo.is_compiling()
except Exception:
return False

def maybe_contiguous(x):
return x.contiguous() if x is not None and x.stride(-1) != 1 else x

@@ -138,7 +158,7 @@ def _flash_attn_forward_fake(
return out, softmax_lse, p, rng_state


if torch.__version__ >= "2.4.0":
if USE_PT_COMPILE_OPS:
_wrapped_flash_attn_forward = torch.ops.flash_attn._flash_attn_forward
else:
_wrapped_flash_attn_forward = _flash_attn_forward
@@ -233,7 +253,7 @@ def _flash_attn_varlen_forward_fake(
return out, softmax_lse, p, rng_state


if torch.__version__ >= "2.4.0":
if USE_PT_COMPILE_OPS:
_wrapped_flash_attn_varlen_forward = torch.ops.flash_attn._flash_attn_varlen_forward
else:
_wrapped_flash_attn_varlen_forward = _flash_attn_varlen_forward
@@ -325,7 +345,7 @@ def _flash_attn_backward_fake(
return softmax_d


if torch.__version__ >= "2.4.0":
if USE_PT_COMPILE_OPS:
_wrapped_flash_attn_backward = torch.ops.flash_attn._flash_attn_backward
else:
_wrapped_flash_attn_backward = _flash_attn_backward
@@ -436,7 +456,7 @@ def _flash_attn_varlen_backward_fake(
return softmax_d


if torch.__version__ >= "2.4.0":
if USE_PT_COMPILE_OPS:
_wrapped_flash_attn_varlen_backward = torch.ops.flash_attn._flash_attn_varlen_backward
else:
_wrapped_flash_attn_varlen_backward = _flash_attn_varlen_backward
@@ -1465,27 +1485,118 @@ def flash_attn_varlen_func(
torch.is_grad_enabled(),
)

@_torch_custom_op_wrapper(
"flash_attn::_flash_attn_with_kvcache",
mutates_args=("k_cache", "v_cache"),
device_types="cuda"
)
def _flash_attn_with_kvcache(
q: torch.Tensor,
k_cache: torch.Tensor,
v_cache: torch.Tensor,
k: Optional[torch.Tensor],
v: Optional[torch.Tensor],
cache_seqlens: Optional[torch.Tensor],
rotary_cos: Optional[torch.Tensor],
rotary_sin: Optional[torch.Tensor],
cache_batch_idx: Optional[torch.Tensor],
cache_leftpad: Optional[torch.Tensor],
block_table: Optional[torch.Tensor],
alibi_slopes: Optional[torch.Tensor],
softmax_scale: Optional[float] = None,
causal: bool = False,
window_size_left: int = -1,
window_size_right: int = -1,
softcap: float = 0.0,
rotary_interleaved: bool = True,
num_splits: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Return:
out: (batch_size, seqlen, nheads, headdim).
softmax_lse: (batch_size, nheads, seqlen). The
logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
normalization factor).
"""
return flash_attn_gpu.fwd_kvcache(
q,
k_cache,
v_cache,
k,
v,
cache_seqlens,
rotary_cos,
rotary_sin,
cache_batch_idx,
cache_leftpad,
block_table,
alibi_slopes,
None,
softmax_scale,
causal,
window_size_left,
window_size_right,
softcap,
rotary_interleaved,
num_splits,
)


@torch.library.register_fake("flash_attn::_flash_attn_with_kvcache")
def _flash_attn_with_kvcache_fake(
q: torch.Tensor,
k_cache: torch.Tensor,
v_cache: torch.Tensor,
k: Optional[torch.Tensor],
v: Optional[torch.Tensor],
cache_seqlens: Optional[torch.Tensor],
rotary_cos: Optional[torch.Tensor],
rotary_sin: Optional[torch.Tensor],
cache_batch_idx: Optional[torch.Tensor],
cache_leftpad: Optional[torch.Tensor],
block_table: Optional[torch.Tensor],
alibi_slopes: Optional[torch.Tensor],
softmax_scale: Optional[float] = None,
causal: bool = False,
window_size_left: int = -1,
window_size_right: int = -1,
softcap: float = 0.0,
rotary_interleaved: bool = True,
num_splits: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor]:
batch_size, seqlen_q, num_heads, head_size = q.shape
out = torch.empty_like(q)
softmax_lse = torch.empty((batch_size, num_heads, seqlen_q), dtype=torch.float32, device=q.device, layout=q.layout)

return out, softmax_lse


if USE_PT_COMPILE_OPS:
_wrapped_flash_attn_with_kvcache = torch.ops.flash_attn._flash_attn_with_kvcache
else:
_wrapped_flash_attn_with_kvcache = _flash_attn_with_kvcache


def flash_attn_with_kvcache(
q,
k_cache,
v_cache,
k=None,
v=None,
rotary_cos=None,
rotary_sin=None,
cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
q: torch.Tensor,
k_cache: torch.Tensor,
v_cache: torch.Tensor,
k: Optional[torch.Tensor] = None,
v: Optional[torch.Tensor] = None,
rotary_cos: Optional[torch.Tensor] = None,
rotary_sin: Optional[torch.Tensor] = None,
cache_seqlens: Optional[Union[int, torch.Tensor]] = None,
cache_batch_idx: Optional[torch.Tensor] = None,
cache_leftpad: Optional[torch.Tensor] = None,
block_table: Optional[torch.Tensor] = None,
softmax_scale=None,
causal=False,
window_size=(-1, -1), # -1 means infinite context window
softcap=0.0, # 0.0 means deactivated
rotary_interleaved=True,
alibi_slopes=None,
num_splits=0,
return_softmax_lse=False,
softmax_scale: Optional[float] = None,
causal: bool = False,
window_size: Sequence[int] = (-1, -1), # -1 means infinite context window
softcap: float = 0.0, # 0.0 means deactivated
rotary_interleaved: bool = True,
alibi_slopes: Optional[torch.Tensor] = None,
num_splits: int = 0,
return_softmax_lse: bool = False,
):
"""
If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
@@ -1586,7 +1697,8 @@ def flash_attn_with_kvcache(
cache_seqlens = maybe_contiguous(cache_seqlens)
cache_batch_idx = maybe_contiguous(cache_batch_idx)
block_table = maybe_contiguous(block_table)
out, softmax_lse = flash_attn_gpu.fwd_kvcache(

result = _wrapped_flash_attn_with_kvcache(
q,
k_cache,
v_cache,
@@ -1599,7 +1711,6 @@
cache_leftpad,
block_table,
alibi_slopes,
None,
softmax_scale,
causal,
window_size[0],
@@ -1608,4 +1719,9 @@
rotary_interleaved,
num_splits,
)
return (out, softmax_lse) if return_softmax_lse else out

if _is_torch_compiling() and USE_CLONE_FOR_DYNAMO:
    # result is an (out, softmax_lse) pair; clone each tensor individually (a tuple has no .clone()).
    return (result[0].clone(), result[1].clone()) if return_softmax_lse else result[0].clone()

return result if return_softmax_lse else result[0]
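

For reference, a minimal usage sketch of what this change enables: calling flash_attn_with_kvcache inside a torch.compile'd decode step, with the registered custom op letting dynamo trace through the CUDA kernel. This is not part of the diff; the decode_step helper, the shapes, and the single-step loop below are illustrative assumptions, and it presumes PT >= 2.4 with a CUDA build of flash-attn that includes this patch.

# --- Usage sketch (illustrative, not part of this PR) ---
# Assumes a CUDA GPU, PyTorch >= 2.4, and a flash-attn build with this patch applied.
# The helper name decode_step and all sizes below are made up for the example.
import torch

from flash_attn import flash_attn_with_kvcache

batch, nheads, headdim, max_seqlen = 2, 8, 64, 256
dtype = torch.float16

# Preallocated KV cache; the kernel updates it in place when k/v are supplied.
k_cache = torch.zeros(batch, max_seqlen, nheads, headdim, device="cuda", dtype=dtype)
v_cache = torch.zeros(batch, max_seqlen, nheads, headdim, device="cuda", dtype=dtype)
cache_seqlens = torch.zeros(batch, dtype=torch.int32, device="cuda")


def decode_step(q, k_new, v_new):
    # Appends k_new/v_new at position cache_seqlens and attends over the cache.
    return flash_attn_with_kvcache(
        q, k_cache, v_cache, k=k_new, v=v_new,
        cache_seqlens=cache_seqlens, causal=True,
    )


# With the custom op registered, this path is compileable instead of graph-breaking.
compiled_decode_step = torch.compile(decode_step)

q = torch.randn(batch, 1, nheads, headdim, device="cuda", dtype=dtype)
k_new = torch.randn_like(q)
v_new = torch.randn_like(q)

out = compiled_decode_step(q, k_new, v_new)  # (batch, 1, nheads, headdim)
cache_seqlens += 1  # the caller advances the cache length between steps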