Skip to content

Commit 9063456

Browse files
ggerganov authored and furyhawk committed
metal : use F32 accumulators in FA kernels (ggml-org#13975)
ggml-ci
1 parent ccb2d87 commit 9063456

File tree

2 files changed

+57
-45
lines changed

2 files changed

+57
-45
lines changed

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4766,14 +4766,16 @@ static bool ggml_metal_encode_node(
47664766
GGML_ASSERT(nqptg % 8 == 0);
47674767
GGML_ASSERT(ncpsg % 32 == 0);
47684768

4769+
const int is_q = ggml_is_quantized(src1->type) ? 1 : 0;
4770+
47694771
// 2*(2*ncpsg + nqptg)*(nsg)
47704772
// ncpsg soft_max values + ncpsg mask values + a diagonal scaling matrix (in float)
47714773
//
47724774
// 16*32*(nsg)
47734775
// the shared memory needed for the simdgroups to load the KV cache
47744776
// each thread loads (dequantizes) 16 head elements, there are 32 threads in the SG
47754777
//
4776-
#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*(2*ncpsg + nqptg)*(nsg)) + 16*32*(nsg))*(sizeof(float)/2), 16))
4778+
#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(2*ne00 + 2*(2*ncpsg + nqptg)*(nsg)) + is_q*(16*32*(nsg)))*(sizeof(float)/2), 16))
47774779

47784780
int64_t nsgmax = 2;
47794781

@@ -4810,9 +4812,9 @@ static bool ggml_metal_encode_node(
48104812
// and store the soft_max values and the mask
48114813
//
48124814
// ne00*(nsg)
4813-
// each simdgroup has a full f16 head vector in shared mem to accumulate results
4815+
// each simdgroup has a full f32 head vector in shared mem to accumulate results
48144816
//
4815-
#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + ne20*(nsg))*(sizeof(float)/2), 16))
4817+
#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + 2*ne20*(nsg))*(sizeof(float)/2), 16))
48164818

48174819
int64_t nsgmax = 2;
48184820
while (true) {

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 52 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -3328,14 +3328,14 @@ kernel void kernel_flash_attn_ext(
33283328
constexpr short NW = N_SIMDWIDTH;
33293329
constexpr short SH = (2*C + Q); // shared memory per simdgroup (s_t == float)
33303330

3331-
const short TS = nsg*SH; // shared memory size per query in (s_t == float)
3332-
const short T = DK + 2*TS; // shared memory size per query in (half)
3331+
const short TS = nsg*SH; // shared memory size per query in (s_t == float)
3332+
const short T = 2*DK + 2*TS; // shared memory size per query in (half)
33333333

3334-
threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data
3335-
threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*DK); // same as above but in q4_t
3336-
threadgroup o_t * so = (threadgroup o_t *) (shmem_f16 + 0*DK); // reuse query data for accumulation
3337-
threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 0*DK); // same as above but in o4_t
3338-
threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + 2*sgitg*SH + Q*DK); // scratch buffer for attention, mask and diagonal matrix
3334+
threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data
3335+
threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*DK); // same as above but in q4_t
3336+
threadgroup o_t * so = (threadgroup o_t *) (shmem_f16 + 0*DK); // reuse query data for accumulation
3337+
threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 0*DK); // same as above but in o4_t
3338+
threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + 2*sgitg*SH + 2*Q*DK); // scratch buffer for attention, mask and diagonal matrix
33393339

33403340
threadgroup k_t * sk = (threadgroup k_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K in shared memory
33413341
threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // same as above but in k4x4_t
@@ -3354,7 +3354,7 @@ kernel void kernel_flash_attn_ext(
33543354
if (iq1 + j < args.ne01) {
33553355
sq4[j*DK4 + i] = (q4_t) q4[i];
33563356
} else {
3357-
sq4[j*DK4 + i] = (q4_t) 0.0f;
3357+
sq4[j*DK4 + i] = 0;
33583358
}
33593359
}
33603360
}
@@ -3634,9 +3634,6 @@ kernel void kernel_flash_attn_ext(
36343634

36353635
// reduce the warps sequentially
36363636
for (ushort sg = 1; sg < nsg; ++sg) {
3637-
float S = { 0.0f };
3638-
float M = { -__FLT_MAX__/2 };
3639-
36403637
threadgroup_barrier(mem_flags::mem_threadgroup);
36413638

36423639
// each simdgroup stores its output to shared memory, reusing sq
@@ -3657,12 +3654,12 @@ kernel void kernel_flash_attn_ext(
36573654
const float M0 = ss[j*TS + 1];
36583655
const float M1 = ss[j*TS + sg*SH + 1];
36593656

3660-
M = max(M0, M1);
3657+
const float M = max(M0, M1);
36613658

36623659
const float ms0 = exp(M0 - M);
36633660
const float ms1 = exp(M1 - M);
36643661

3665-
S = S0*ms0 + S1*ms1;
3662+
const float S = S0*ms0 + S1*ms1;
36663663

36673664
if (tiisg == 0) {
36683665
ss[j*TS + 0] = S;
@@ -3701,16 +3698,18 @@ kernel void kernel_flash_attn_ext(
37013698
}
37023699
}
37033700

3704-
device float4 * dst4 = (device float4 *) dst;
3701+
threadgroup_barrier(mem_flags::mem_threadgroup);
3702+
3703+
threadgroup s_t * sf = (threadgroup s_t *) (shmem_f16 + 2*Q*DK);
37053704

37063705
// final rescale with 1/S and store to global memory
3707-
if (sgitg == 0) {
3708-
for (short j = 0; j < Q && iq1 + j < args.ne01; ++j) {
3709-
const float S = ss[j*TS + 0];
3706+
for (short j = sgitg; j < Q && iq1 + j < args.ne01; j += nsg) {
3707+
const float S = 1.0f/sf[j*TS + 0];
37103708

3711-
for (short i = tiisg; i < DV4; i += NW) {
3712-
dst4[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*DV4 + i] = (float4) so4[j*DV4 + i]/S;
3713-
}
3709+
device float4 * dst4 = (device float4 *) dst + ((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*DV4;
3710+
3711+
for (short i = tiisg; i < DV4; i += NW) {
3712+
dst4[i] = (float4) so4[j*DV4 + i]*S;
37143713
}
37153714
}
37163715
}
@@ -3719,12 +3718,22 @@ kernel void kernel_flash_attn_ext(
37193718
// template to be able to explore different combinations
37203719
//
37213720
#define FA_TYPES \
3722-
half, half4, simdgroup_half8x8, \
3723-
half, half4x4, simdgroup_half8x8, \
3724-
half, half4x4, simdgroup_half8x8, \
3725-
float, simdgroup_float8x8, \
3726-
float, simdgroup_float8x8, \
3727-
half, half4, simdgroup_half8x8
3721+
float, float4, simdgroup_float8x8, \
3722+
half, half4x4, simdgroup_half8x8, \
3723+
half, half4x4, simdgroup_half8x8, \
3724+
float, simdgroup_float8x8, \
3725+
float, simdgroup_float8x8, \
3726+
float, float4, simdgroup_float8x8
3727+
//half, half4, simdgroup_half8x8
3728+
3729+
#define FA_TYPES_BF \
3730+
bfloat, bfloat4, simdgroup_bfloat8x8, \
3731+
bfloat, bfloat4x4, simdgroup_bfloat8x8, \
3732+
bfloat, bfloat4x4, simdgroup_bfloat8x8, \
3733+
float, simdgroup_float8x8, \
3734+
float, simdgroup_float8x8, \
3735+
float, float4, simdgroup_float8x8
3736+
//half, half4, simdgroup_half8x8
37283737

37293738
typedef decltype(kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64, 64>) flash_attn_ext_t;
37303739

@@ -3739,15 +3748,15 @@ template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_at
37393748
template [[host_name("kernel_flash_attn_ext_f16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 576, 512>;
37403749

37413750
#if defined(GGML_METAL_USE_BF16)
3742-
template [[host_name("kernel_flash_attn_ext_bf16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 64, 64>;
3743-
template [[host_name("kernel_flash_attn_ext_bf16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 80, 80>;
3744-
template [[host_name("kernel_flash_attn_ext_bf16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 96, 96>;
3745-
template [[host_name("kernel_flash_attn_ext_bf16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 112, 112>;
3746-
template [[host_name("kernel_flash_attn_ext_bf16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 128, 128>;
3747-
template [[host_name("kernel_flash_attn_ext_bf16_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 192, 192>;
3748-
template [[host_name("kernel_flash_attn_ext_bf16_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 192, 128>;
3749-
template [[host_name("kernel_flash_attn_ext_bf16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 256, 256>;
3750-
template [[host_name("kernel_flash_attn_ext_bf16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 576, 512>;
3751+
template [[host_name("kernel_flash_attn_ext_bf16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 64, 64>;
3752+
template [[host_name("kernel_flash_attn_ext_bf16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 80, 80>;
3753+
template [[host_name("kernel_flash_attn_ext_bf16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 96, 96>;
3754+
template [[host_name("kernel_flash_attn_ext_bf16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 112, 112>;
3755+
template [[host_name("kernel_flash_attn_ext_bf16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 128, 128>;
3756+
template [[host_name("kernel_flash_attn_ext_bf16_h192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 192, 192>;
3757+
template [[host_name("kernel_flash_attn_ext_bf16_hk192_hv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 192, 128>;
3758+
template [[host_name("kernel_flash_attn_ext_bf16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 256, 256>;
3759+
template [[host_name("kernel_flash_attn_ext_bf16_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 576, 512>;
37513760
#endif
37523761

37533762
template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 64, 64>;
@@ -3801,6 +3810,7 @@ template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_at
38013810
template [[host_name("kernel_flash_attn_ext_q8_0_hk576_hv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 576, 512>;
38023811

38033812
#undef FA_TYPES
3813+
#undef FA_TYPES_BF
38043814

38053815
template<
38063816
typename q4_t, // query types in shared memory
@@ -3847,12 +3857,12 @@ kernel void kernel_flash_attn_ext_vec(
38473857

38483858
const short T = DK + nsg*SH; // shared memory size per query in (half)
38493859

3850-
//threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data
3851-
threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*DK); // same as above but in q4_t
3852-
threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + Q*DK); // scratch buffer for attention
3853-
threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + Q*DK); // same as above but in s4_t
3854-
threadgroup float * sm = (threadgroup float *) (shmem_f16 + sgitg*SH + 2*C + Q*DK); // scratch buffer for mask
3855-
threadgroup o4_t * sr4 = (threadgroup o4_t *) (shmem_f16 + sgitg*DV + Q*T); // scratch buffer for the results
3860+
//threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*DK); // holds the query data
3861+
threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*DK); // same as above but in q4_t
3862+
threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + Q*DK); // scratch buffer for attention
3863+
threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + Q*DK); // same as above but in s4_t
3864+
threadgroup float * sm = (threadgroup float *) (shmem_f16 + sgitg*SH + 2*C + Q*DK); // scratch buffer for mask
3865+
threadgroup o4_t * sr4 = (threadgroup o4_t *) (shmem_f16 + 2*sgitg*DV + Q*T); // scratch buffer for the results
38563866

38573867
// store the result for all queries in local memory (the O matrix from the paper)
38583868
o4_t lo[DV4/NL];
@@ -4157,7 +4167,7 @@ kernel void kernel_flash_attn_ext_vec(
41574167
half4, \
41584168
float, \
41594169
float, float4, \
4160-
half4
4170+
float4
41614171

41624172
typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 4>) flash_attn_ext_vec_t;
41634173

0 commit comments

Comments
 (0)