Skip to content

Commit 40fa35a

Browse files
committed
Always enable PackGQA if PagedKV to reduce compilation and bin size
1 parent a84a237 commit 40fa35a

File tree

135 files changed

+75
-720
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

135 files changed

+75
-720
lines changed

hopper/flash_api.cpp

+6-6
Original file line numberDiff line numberDiff line change
@@ -265,8 +265,8 @@ void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream) {
265265
SPLIT_SWITCH(params.num_splits > 1, Split, [&] {
266266
PAGEDKV_SWITCH(params.page_table, PagedKV, [&] {
267267
PACKGQA_SWITCH(params.pack_gqa, PackGQA_, [&] {
268-
// Always enable PackGQA for Sm8x to reduce compilation
269-
static constexpr bool PackGQA = PackGQA_ || Arch < 90;
268+
// Always enable PackGQA for Sm8x or PagedKV to reduce compilation
269+
static constexpr bool PackGQA = PackGQA_ || Arch < 90 || PagedKV;
270270
SOFTCAP_SWITCH(params.softcap > 0.0, Has_softcap, [&] {
271271
if (!params.is_e4m3) {
272272
if (params.is_bf16) {
@@ -369,9 +369,9 @@ void run_mha_fwd_combine(Flash_fwd_params &params, cudaStream_t stream) {
369369
}
370370

371371
inline bool get_pack_gqa(Flash_fwd_params const& params) {
372-
// Always enable PackGQA for Sm8x to reduce compilation and binary size.
373-
// Has almost no effect on speed.
374-
if (params.arch < 90) { return true; }
372+
// Always enable PackGQA for Sm8x or PagedKV to reduce compilation and binary size.
373+
// Has little effect on speed.
374+
if (params.arch < 90 || params.page_table) { return true; }
375375
#ifdef FLASHATTENTION_DISABLE_PACKGQA
376376
return false;
377377
#else
@@ -838,7 +838,7 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq
838838
TORCH_CHECK(params.num_splits == 1, "This flash attention build does not support splits.");
839839
#endif
840840
#ifdef FLASHATTENTION_DISABLE_PACKGQA
841-
TORCH_CHECK(params.arch < 90 || !params.pack_gqa, "This flash attention build does not support pack_gqa.");
841+
TORCH_CHECK(!params.pack_gqa || params.arch < 90 || params.page_table, "This flash attention build does not support pack_gqa.");
842842
#endif
843843
#ifdef FLASHATTENTION_DISABLE_PAGEDKV
844844
TORCH_CHECK(!paged_KV, "This flash attention build does not support paged KV.");

hopper/generate_kernels.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -95,10 +95,12 @@ class Kernel:
9595
def template(self) -> str:
9696
if self.direction == "fwd":
9797
if self.sm == 90:
98+
# Always enable PackGQA for PagedKV to reduce compilation
99+
packgqa = self.packgqa or self.paged_kv
98100
return KERNEL_IMPL_TEMPLATE_FWD_SM90.format(
99101
ARCH=str(self.sm), DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim,
100102
SPLIT=str(self.split).lower(), PAGEDKV=str(self.paged_kv).lower(),
101-
SOFTCAP=str(self.softcap).lower(), PACKGQA=str(self.packgqa).lower()
103+
SOFTCAP=str(self.softcap).lower(), PACKGQA=str(packgqa).lower()
102104
)
103105
else:
104106
# Always enable PackGQA for Sm8x to reduce compilation
@@ -126,9 +128,9 @@ def filename(self) -> str:
126128

127129
def get_all_kernels() -> List[Kernel]:
128130
for dtype, head_dim, split, paged_kv, softcap, packgqa, sm in itertools.product(DTYPE_MAP.keys(), HEAD_DIMENSIONS, SPLIT, PAGEDKV, SOFTCAP, PACKGQA, SM):
129-
# We always enable PackGQA for Sm8x so we should just pass in packgqa=False
131+
# We always enable PackGQA for Sm8x and PagedKV so we should just pass in packgqa=False
130132
# to avoid the `_packgqa` in the filename.
131-
if sm < 90 and packgqa:
133+
if packgqa and (sm < 90 or (sm >= 90 and paged_kv)):
132134
continue
133135
if sm >= 90 or dtype in DTYPE_MAP_FWD_SM8x:
134136
yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, split=split, paged_kv=paged_kv, softcap=softcap, packgqa=packgqa, direction="fwd")

hopper/instantiations/flash_fwd_hdim128_bf16_paged_packgqa_sm90.cu

-9
This file was deleted.

hopper/instantiations/flash_fwd_hdim128_bf16_paged_sm90.cu

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,5 @@
55
#include "flash_fwd_launch_template.h"
66

77
#ifndef FLASHATTENTION_DISABLE_HDIM128
8-
template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, false, true, false, false>(Flash_fwd_params &params, cudaStream_t stream);
8+
template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, false, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);
99
#endif

hopper/instantiations/flash_fwd_hdim128_bf16_paged_softcap_packgqa_sm90.cu

-9
This file was deleted.

hopper/instantiations/flash_fwd_hdim128_bf16_paged_softcap_sm90.cu

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,5 @@
55
#include "flash_fwd_launch_template.h"
66

77
#ifndef FLASHATTENTION_DISABLE_HDIM128
8-
template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, false, true, true, false>(Flash_fwd_params &params, cudaStream_t stream);
8+
template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);
99
#endif

hopper/instantiations/flash_fwd_hdim128_bf16_paged_split_packgqa_sm90.cu

-9
This file was deleted.

hopper/instantiations/flash_fwd_hdim128_bf16_paged_split_sm90.cu

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,5 @@
55
#include "flash_fwd_launch_template.h"
66

77
#ifndef FLASHATTENTION_DISABLE_HDIM128
8-
template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, true, true, false, false>(Flash_fwd_params &params, cudaStream_t stream);
8+
template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);
99
#endif

hopper/instantiations/flash_fwd_hdim128_bf16_paged_split_softcap_packgqa_sm90.cu

-9
This file was deleted.

hopper/instantiations/flash_fwd_hdim128_bf16_paged_split_softcap_sm90.cu

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,5 @@
55
#include "flash_fwd_launch_template.h"
66

77
#ifndef FLASHATTENTION_DISABLE_HDIM128
8-
template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, true, true, true, false>(Flash_fwd_params &params, cudaStream_t stream);
8+
template void run_mha_fwd_<90, cutlass::bfloat16_t, 128, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);
99
#endif

hopper/instantiations/flash_fwd_hdim128_e4m3_paged_packgqa_sm90.cu

-9
This file was deleted.

hopper/instantiations/flash_fwd_hdim128_e4m3_paged_sm90.cu

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,5 @@
55
#include "flash_fwd_launch_template.h"
66

77
#ifndef FLASHATTENTION_DISABLE_HDIM128
8-
template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, false, true, false, false>(Flash_fwd_params &params, cudaStream_t stream);
8+
template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, false, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);
99
#endif

hopper/instantiations/flash_fwd_hdim128_e4m3_paged_softcap_packgqa_sm90.cu

-9
This file was deleted.

hopper/instantiations/flash_fwd_hdim128_e4m3_paged_softcap_sm90.cu

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,5 @@
55
#include "flash_fwd_launch_template.h"
66

77
#ifndef FLASHATTENTION_DISABLE_HDIM128
8-
template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, false, true, true, false>(Flash_fwd_params &params, cudaStream_t stream);
8+
template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);
99
#endif

hopper/instantiations/flash_fwd_hdim128_e4m3_paged_split_packgqa_sm90.cu

-9
This file was deleted.

hopper/instantiations/flash_fwd_hdim128_e4m3_paged_split_sm90.cu

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,5 @@
55
#include "flash_fwd_launch_template.h"
66

77
#ifndef FLASHATTENTION_DISABLE_HDIM128
8-
template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, true, true, false, false>(Flash_fwd_params &params, cudaStream_t stream);
8+
template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);
99
#endif

hopper/instantiations/flash_fwd_hdim128_e4m3_paged_split_softcap_packgqa_sm90.cu

-9
This file was deleted.

hopper/instantiations/flash_fwd_hdim128_e4m3_paged_split_softcap_sm90.cu

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,5 @@
55
#include "flash_fwd_launch_template.h"
66

77
#ifndef FLASHATTENTION_DISABLE_HDIM128
8-
template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, true, true, true, false>(Flash_fwd_params &params, cudaStream_t stream);
8+
template void run_mha_fwd_<90, cutlass::float_e4m3_t, 128, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);
99
#endif

hopper/instantiations/flash_fwd_hdim128_fp16_paged_packgqa_sm90.cu

-9
This file was deleted.

hopper/instantiations/flash_fwd_hdim128_fp16_paged_sm90.cu

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,5 @@
55
#include "flash_fwd_launch_template.h"
66

77
#ifndef FLASHATTENTION_DISABLE_HDIM128
8-
template void run_mha_fwd_<90, cutlass::half_t, 128, false, true, false, false>(Flash_fwd_params &params, cudaStream_t stream);
8+
template void run_mha_fwd_<90, cutlass::half_t, 128, false, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);
99
#endif

hopper/instantiations/flash_fwd_hdim128_fp16_paged_softcap_packgqa_sm90.cu

-9
This file was deleted.

hopper/instantiations/flash_fwd_hdim128_fp16_paged_softcap_sm90.cu

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,5 @@
55
#include "flash_fwd_launch_template.h"
66

77
#ifndef FLASHATTENTION_DISABLE_HDIM128
8-
template void run_mha_fwd_<90, cutlass::half_t, 128, false, true, true, false>(Flash_fwd_params &params, cudaStream_t stream);
8+
template void run_mha_fwd_<90, cutlass::half_t, 128, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);
99
#endif

hopper/instantiations/flash_fwd_hdim128_fp16_paged_split_packgqa_sm90.cu

-9
This file was deleted.

hopper/instantiations/flash_fwd_hdim128_fp16_paged_split_sm90.cu

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,5 @@
55
#include "flash_fwd_launch_template.h"
66

77
#ifndef FLASHATTENTION_DISABLE_HDIM128
8-
template void run_mha_fwd_<90, cutlass::half_t, 128, true, true, false, false>(Flash_fwd_params &params, cudaStream_t stream);
8+
template void run_mha_fwd_<90, cutlass::half_t, 128, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);
99
#endif

hopper/instantiations/flash_fwd_hdim128_fp16_paged_split_softcap_packgqa_sm90.cu

-9
This file was deleted.

hopper/instantiations/flash_fwd_hdim128_fp16_paged_split_softcap_sm90.cu

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,5 @@
55
#include "flash_fwd_launch_template.h"
66

77
#ifndef FLASHATTENTION_DISABLE_HDIM128
8-
template void run_mha_fwd_<90, cutlass::half_t, 128, true, true, true, false>(Flash_fwd_params &params, cudaStream_t stream);
8+
template void run_mha_fwd_<90, cutlass::half_t, 128, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);
99
#endif

hopper/instantiations/flash_fwd_hdim192_bf16_paged_packgqa_sm90.cu

-9
This file was deleted.

hopper/instantiations/flash_fwd_hdim192_bf16_paged_sm90.cu

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,5 @@
55
#include "flash_fwd_launch_template.h"
66

77
#ifndef FLASHATTENTION_DISABLE_HDIM192
8-
template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, false, true, false, false>(Flash_fwd_params &params, cudaStream_t stream);
8+
template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, false, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);
99
#endif

hopper/instantiations/flash_fwd_hdim192_bf16_paged_softcap_packgqa_sm90.cu

-9
This file was deleted.

hopper/instantiations/flash_fwd_hdim192_bf16_paged_softcap_sm90.cu

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,5 @@
55
#include "flash_fwd_launch_template.h"
66

77
#ifndef FLASHATTENTION_DISABLE_HDIM192
8-
template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, false, true, true, false>(Flash_fwd_params &params, cudaStream_t stream);
8+
template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);
99
#endif

hopper/instantiations/flash_fwd_hdim192_bf16_paged_split_packgqa_sm90.cu

-9
This file was deleted.

hopper/instantiations/flash_fwd_hdim192_bf16_paged_split_sm90.cu

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,5 @@
55
#include "flash_fwd_launch_template.h"
66

77
#ifndef FLASHATTENTION_DISABLE_HDIM192
8-
template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, true, true, false, false>(Flash_fwd_params &params, cudaStream_t stream);
8+
template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);
99
#endif

hopper/instantiations/flash_fwd_hdim192_bf16_paged_split_softcap_packgqa_sm90.cu

-9
This file was deleted.

hopper/instantiations/flash_fwd_hdim192_bf16_paged_split_softcap_sm90.cu

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,5 @@
55
#include "flash_fwd_launch_template.h"
66

77
#ifndef FLASHATTENTION_DISABLE_HDIM192
8-
template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, true, true, true, false>(Flash_fwd_params &params, cudaStream_t stream);
8+
template void run_mha_fwd_<90, cutlass::bfloat16_t, 192, true, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);
99
#endif

hopper/instantiations/flash_fwd_hdim192_e4m3_paged_packgqa_sm90.cu

-9
This file was deleted.

hopper/instantiations/flash_fwd_hdim192_e4m3_paged_sm90.cu

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,5 @@
55
#include "flash_fwd_launch_template.h"
66

77
#ifndef FLASHATTENTION_DISABLE_HDIM192
8-
template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, false, true, false, false>(Flash_fwd_params &params, cudaStream_t stream);
8+
template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, false, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);
99
#endif

hopper/instantiations/flash_fwd_hdim192_e4m3_paged_softcap_packgqa_sm90.cu

-9
This file was deleted.

hopper/instantiations/flash_fwd_hdim192_e4m3_paged_softcap_sm90.cu

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,5 @@
55
#include "flash_fwd_launch_template.h"
66

77
#ifndef FLASHATTENTION_DISABLE_HDIM192
8-
template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, false, true, true, false>(Flash_fwd_params &params, cudaStream_t stream);
8+
template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, false, true, true, true>(Flash_fwd_params &params, cudaStream_t stream);
99
#endif

hopper/instantiations/flash_fwd_hdim192_e4m3_paged_split_packgqa_sm90.cu

-9
This file was deleted.

hopper/instantiations/flash_fwd_hdim192_e4m3_paged_split_sm90.cu

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,5 @@
55
#include "flash_fwd_launch_template.h"
66

77
#ifndef FLASHATTENTION_DISABLE_HDIM192
8-
template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, true, true, false, false>(Flash_fwd_params &params, cudaStream_t stream);
8+
template void run_mha_fwd_<90, cutlass::float_e4m3_t, 192, true, true, false, true>(Flash_fwd_params &params, cudaStream_t stream);
99
#endif

0 commit comments

Comments
 (0)