@@ -3328,14 +3328,14 @@ kernel void kernel_flash_attn_ext(
3328
3328
constexpr short NW = N_SIMDWIDTH;
3329
3329
constexpr short SH = (2 *C + Q); // shared memory per simdgroup (s_t == float)
3330
3330
3331
- const short TS = nsg*SH; // shared memory size per query in (s_t == float)
3332
- const short T = DK + 2 *TS; // shared memory size per query in (half)
3331
+ const short TS = nsg*SH; // shared memory size per query in (s_t == float)
3332
+ const short T = 2 * DK + 2 *TS; // shared memory size per query in (half)
3333
3333
3334
- threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0 *DK); // holds the query data
3335
- threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0 *DK); // same as above but in q4_t
3336
- threadgroup o_t * so = (threadgroup o_t *) (shmem_f16 + 0 *DK); // reuse query data for accumulation
3337
- threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 0 *DK); // same as above but in o4_t
3338
- threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + 2 *sgitg*SH + Q*DK); // scratch buffer for attention, mask and diagonal matrix
3334
+ threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0 *DK); // holds the query data
3335
+ threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0 *DK); // same as above but in q4_t
3336
+ threadgroup o_t * so = (threadgroup o_t *) (shmem_f16 + 0 *DK); // reuse query data for accumulation
3337
+ threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 0 *DK); // same as above but in o4_t
3338
+ threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + 2 *sgitg*SH + 2 * Q*DK); // scratch buffer for attention, mask and diagonal matrix
3339
3339
3340
3340
threadgroup k_t * sk = (threadgroup k_t *) (shmem_f16 + sgitg*(4 *16 *KV) + Q*T); // scratch buffer to load K in shared memory
3341
3341
threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shmem_f16 + sgitg*(4 *16 *KV) + Q*T); // same as above but in k4x4_t
@@ -3354,7 +3354,7 @@ kernel void kernel_flash_attn_ext(
3354
3354
if (iq1 + j < args.ne01 ) {
3355
3355
sq4[j*DK4 + i] = (q4_t ) q4[i];
3356
3356
} else {
3357
- sq4[j*DK4 + i] = ( q4_t ) 0 . 0f ;
3357
+ sq4[j*DK4 + i] = 0 ;
3358
3358
}
3359
3359
}
3360
3360
}
@@ -3634,9 +3634,6 @@ kernel void kernel_flash_attn_ext(
3634
3634
3635
3635
// reduce the warps sequentially
3636
3636
for (ushort sg = 1 ; sg < nsg; ++sg) {
3637
- float S = { 0 .0f };
3638
- float M = { -__FLT_MAX__/2 };
3639
-
3640
3637
threadgroup_barrier (mem_flags::mem_threadgroup);
3641
3638
3642
3639
// each simdgroup stores its output to shared memory, reusing sq
@@ -3657,12 +3654,12 @@ kernel void kernel_flash_attn_ext(
3657
3654
const float M0 = ss[j*TS + 1 ];
3658
3655
const float M1 = ss[j*TS + sg*SH + 1 ];
3659
3656
3660
- M = max (M0, M1);
3657
+ const float M = max (M0, M1);
3661
3658
3662
3659
const float ms0 = exp (M0 - M);
3663
3660
const float ms1 = exp (M1 - M);
3664
3661
3665
- S = S0*ms0 + S1*ms1;
3662
+ const float S = S0*ms0 + S1*ms1;
3666
3663
3667
3664
if (tiisg == 0 ) {
3668
3665
ss[j*TS + 0 ] = S;
@@ -3701,16 +3698,18 @@ kernel void kernel_flash_attn_ext(
3701
3698
}
3702
3699
}
3703
3700
3704
- device float4 * dst4 = (device float4 *) dst;
3701
+ threadgroup_barrier (mem_flags::mem_threadgroup);
3702
+
3703
+ threadgroup s_t * sf = (threadgroup s_t *) (shmem_f16 + 2 *Q*DK);
3705
3704
3706
3705
// final rescale with 1/S and store to global memory
3707
- if (sgitg == 0 ) {
3708
- for (short j = 0 ; j < Q && iq1 + j < args.ne01 ; ++j) {
3709
- const float S = ss[j*TS + 0 ];
3706
+ for (short j = sgitg; j < Q && iq1 + j < args.ne01 ; j += nsg) {
3707
+ const float S = 1 .0f /sf[j*TS + 0 ];
3710
3708
3711
- for (short i = tiisg; i < DV4; i += NW) {
3712
- dst4[((uint64_t )iq3*args.ne2 *args.ne1 + iq2 + (uint64_t )(iq1 + j)*args.ne1 )*DV4 + i] = (float4) so4[j*DV4 + i]/S;
3713
- }
3709
+ device float4 * dst4 = (device float4 *) dst + ((uint64_t )iq3*args.ne2 *args.ne1 + iq2 + (uint64_t )(iq1 + j)*args.ne1 )*DV4;
3710
+
3711
+ for (short i = tiisg; i < DV4; i += NW) {
3712
+ dst4[i] = (float4) so4[j*DV4 + i]*S;
3714
3713
}
3715
3714
}
3716
3715
}
@@ -3719,12 +3718,22 @@ kernel void kernel_flash_attn_ext(
3719
3718
// template to be able to explore different combinations
3720
3719
//
3721
3720
#define FA_TYPES \
3722
- half, half4, simdgroup_half8x8, \
3723
- half, half4x4, simdgroup_half8x8, \
3724
- half, half4x4, simdgroup_half8x8, \
3725
- float , simdgroup_float8x8, \
3726
- float , simdgroup_float8x8, \
3727
- half, half4, simdgroup_half8x8
3721
+ float , float4, simdgroup_float8x8, \
3722
+ half, half4x4, simdgroup_half8x8, \
3723
+ half, half4x4, simdgroup_half8x8, \
3724
+ float , simdgroup_float8x8, \
3725
+ float , simdgroup_float8x8, \
3726
+ float , float4, simdgroup_float8x8
3727
+ // half, half4, simdgroup_half8x8
3728
+
3729
+ #define FA_TYPES_BF \
3730
+ bfloat, bfloat4, simdgroup_bfloat8x8, \
3731
+ bfloat, bfloat4x4, simdgroup_bfloat8x8, \
3732
+ bfloat, bfloat4x4, simdgroup_bfloat8x8, \
3733
+ float , simdgroup_float8x8, \
3734
+ float , simdgroup_float8x8, \
3735
+ float , float4, simdgroup_float8x8
3736
+ // half, half4, simdgroup_half8x8
3728
3737
3729
3738
typedef decltype (kernel_flash_attn_ext<FA_TYPES, half4x4, 1 , dequantize_f16, half4x4, 1 , dequantize_f16, 64 , 64 >) flash_attn_ext_t;
3730
3739
@@ -3739,15 +3748,15 @@ template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_at
3739
3748
template [[host_name(" kernel_flash_attn_ext_f16_hk576_hv512" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1 , dequantize_f16, half4x4, 1 , dequantize_f16, 576 , 512 >;
3740
3749
3741
3750
#if defined(GGML_METAL_USE_BF16)
3742
- template [[host_name(" kernel_flash_attn_ext_bf16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES , bfloat4x4, 1 , dequantize_bf16, bfloat4x4, 1 , dequantize_bf16, 64 , 64 >;
3743
- template [[host_name(" kernel_flash_attn_ext_bf16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES , bfloat4x4, 1 , dequantize_bf16, bfloat4x4, 1 , dequantize_bf16, 80 , 80 >;
3744
- template [[host_name(" kernel_flash_attn_ext_bf16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES , bfloat4x4, 1 , dequantize_bf16, bfloat4x4, 1 , dequantize_bf16, 96 , 96 >;
3745
- template [[host_name(" kernel_flash_attn_ext_bf16_h112" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES , bfloat4x4, 1 , dequantize_bf16, bfloat4x4, 1 , dequantize_bf16, 112 , 112 >;
3746
- template [[host_name(" kernel_flash_attn_ext_bf16_h128" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES , bfloat4x4, 1 , dequantize_bf16, bfloat4x4, 1 , dequantize_bf16, 128 , 128 >;
3747
- template [[host_name(" kernel_flash_attn_ext_bf16_h192" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES , bfloat4x4, 1 , dequantize_bf16, bfloat4x4, 1 , dequantize_bf16, 192 , 192 >;
3748
- template [[host_name(" kernel_flash_attn_ext_bf16_hk192_hv128" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES , bfloat4x4, 1 , dequantize_bf16, bfloat4x4, 1 , dequantize_bf16, 192 , 128 >;
3749
- template [[host_name(" kernel_flash_attn_ext_bf16_h256" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES , bfloat4x4, 1 , dequantize_bf16, bfloat4x4, 1 , dequantize_bf16, 256 , 256 >;
3750
- template [[host_name(" kernel_flash_attn_ext_bf16_hk576_hv512" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES , bfloat4x4, 1 , dequantize_bf16, bfloat4x4, 1 , dequantize_bf16, 576 , 512 >;
3751
+ template [[host_name(" kernel_flash_attn_ext_bf16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF , bfloat4x4, 1 , dequantize_bf16, bfloat4x4, 1 , dequantize_bf16, 64 , 64 >;
3752
+ template [[host_name(" kernel_flash_attn_ext_bf16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF , bfloat4x4, 1 , dequantize_bf16, bfloat4x4, 1 , dequantize_bf16, 80 , 80 >;
3753
+ template [[host_name(" kernel_flash_attn_ext_bf16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF , bfloat4x4, 1 , dequantize_bf16, bfloat4x4, 1 , dequantize_bf16, 96 , 96 >;
3754
+ template [[host_name(" kernel_flash_attn_ext_bf16_h112" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF , bfloat4x4, 1 , dequantize_bf16, bfloat4x4, 1 , dequantize_bf16, 112 , 112 >;
3755
+ template [[host_name(" kernel_flash_attn_ext_bf16_h128" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF , bfloat4x4, 1 , dequantize_bf16, bfloat4x4, 1 , dequantize_bf16, 128 , 128 >;
3756
+ template [[host_name(" kernel_flash_attn_ext_bf16_h192" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF , bfloat4x4, 1 , dequantize_bf16, bfloat4x4, 1 , dequantize_bf16, 192 , 192 >;
3757
+ template [[host_name(" kernel_flash_attn_ext_bf16_hk192_hv128" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF , bfloat4x4, 1 , dequantize_bf16, bfloat4x4, 1 , dequantize_bf16, 192 , 128 >;
3758
+ template [[host_name(" kernel_flash_attn_ext_bf16_h256" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF , bfloat4x4, 1 , dequantize_bf16, bfloat4x4, 1 , dequantize_bf16, 256 , 256 >;
3759
+ template [[host_name(" kernel_flash_attn_ext_bf16_hk576_hv512" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF , bfloat4x4, 1 , dequantize_bf16, bfloat4x4, 1 , dequantize_bf16, 576 , 512 >;
3751
3760
#endif
3752
3761
3753
3762
template [[host_name(" kernel_flash_attn_ext_q4_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2 , dequantize_q4_0, block_q4_0, 2 , dequantize_q4_0, 64 , 64 >;
@@ -3801,6 +3810,7 @@ template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_at
3801
3810
template [[host_name(" kernel_flash_attn_ext_q8_0_hk576_hv512" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2 , dequantize_q8_0, block_q8_0, 2 , dequantize_q8_0, 576 , 512 >;
3802
3811
3803
3812
#undef FA_TYPES
3813
+ #undef FA_TYPES_BF
3804
3814
3805
3815
template <
3806
3816
typename q4_t , // query types in shared memory
@@ -3847,12 +3857,12 @@ kernel void kernel_flash_attn_ext_vec(
3847
3857
3848
3858
const short T = DK + nsg*SH; // shared memory size per query in (half)
3849
3859
3850
- // threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data
3851
- threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0 *DK); // same as above but in q4_t
3852
- threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + Q*DK); // scratch buffer for attention
3853
- threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + Q*DK); // same as above but in s4_t
3854
- threadgroup float * sm = (threadgroup float *) (shmem_f16 + sgitg*SH + 2 *C + Q*DK); // scratch buffer for mask
3855
- threadgroup o4_t * sr4 = (threadgroup o4_t *) (shmem_f16 + sgitg*DV + Q*T); // scratch buffer for the results
3860
+ // threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data
3861
+ threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0 *DK); // same as above but in q4_t
3862
+ threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + Q*DK); // scratch buffer for attention
3863
+ threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + Q*DK); // same as above but in s4_t
3864
+ threadgroup float * sm = (threadgroup float *) (shmem_f16 + sgitg*SH + 2 *C + Q*DK); // scratch buffer for mask
3865
+ threadgroup o4_t * sr4 = (threadgroup o4_t *) (shmem_f16 + 2 * sgitg*DV + Q*T); // scratch buffer for the results
3856
3866
3857
3867
// store the result for all queries in local memory (the O matrix from the paper)
3858
3868
o4_t lo[DV4/NL];
@@ -4157,7 +4167,7 @@ kernel void kernel_flash_attn_ext_vec(
4157
4167
half4, \
4158
4168
float , \
4159
4169
float , float4, \
4160
- half4
4170
+ float4
4161
4171
4162
4172
typedef decltype (kernel_flash_attn_ext_vec<FA_TYPES, half4, 1 , dequantize_f16_t4, half4, 1 , dequantize_f16_t4, 128 , 128 , 4 >) flash_attn_ext_vec_t;
4163
4173
0 commit comments