Flash MLA (CPU only) #240

Merged (7 commits) on Mar 3, 2025
ggml/src/ggml-quants.c (7 additions, 1 deletion)

@@ -11,6 +11,7 @@
 #include "ggml-quants.h"
 #include "ggml-impl.h"
 #if GGML_USE_IQK_MULMAT
+#include "iqk/iqk_config.h"
 #include "iqk/iqk_mul_mat.h"
 #include "iqk/iqk_quantize.h"
 #endif
@@ -5449,7 +5450,12 @@ void ggml_vec_dot_q6_0_q8_0(int n, float * restrict s, size_t bs, const void * r

 void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
 #if GGML_USE_IQK_MULMAT
-    if (iqk_mul_mat(nrc, nrc, n, GGML_TYPE_Q8_0, vx, bx, GGML_TYPE_Q8_0, vy, by, s, bs, 0, 1)) {
+#ifdef HAVE_FANCY_SIMD
+    enum ggml_type dot_type = GGML_TYPE_Q8_1_X4;
+#else
+    enum ggml_type dot_type = GGML_TYPE_Q8_0_X4;
+#endif
+    if (iqk_mul_mat(nrc, nrc, n, GGML_TYPE_Q8_0, vx, bx, dot_type, vy, by, s, bs, 0, 1)) {
         return;
     }
 #endif
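Why the dot type changes with HAVE_FANCY_SIMD: on AVX512-VNNI builds the Q8 dot product is computed with unsigned-times-signed instructions (e.g. _mm512_dpbusd_epi32), which bias one operand and therefore need the per-block sums that a Q8_1-style layout stores; without VNNI the plain signed Q8_0_X4 layout suffices. A minimal sketch of the bias correction, assuming this is the reason for the switch (the helper below is illustrative, not code from the PR):

    #include <cstdint>

    // One 32-element block, computed "the VNNI way": x is biased by +128 so it
    // can be fed to an unsigned-times-signed instruction, and the result is
    // corrected with 128*sum(y) -- the sum a Q8_1-style block already stores.
    int32_t dot_with_bias_correction(const int8_t * x, const int8_t * y, int32_t sum_y) {
        int32_t acc = 0;
        for (int i = 0; i < 32; ++i) {
            uint8_t xu = (uint8_t)(x[i] + 128);  // biased, unsigned operand
            acc += (int32_t)xu * (int32_t)y[i];  // what dpbusd accumulates
        }
        return acc - 128*sum_y;                  // undo the bias via the stored sum
    }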
ggml/src/ggml.c (4 additions, 2 deletions)

@@ -10451,7 +10451,7 @@ static void ggml_compute_forward_dup_bytes(
         ne00 == ne0 &&
         nb00 == type_size && nb0 == type_size) {
         // copy by rows
-        const size_t rs = ne00 * type_size;
+        const size_t rs = ggml_row_size(src0->type, ne00); //ne00 * type_size;
         for (int64_t i03 = 0; i03 < ne03; i03++) {
             for (int64_t i02 = 0; i02 < ne02; i02++) {
                 for (int64_t i01 = ir0; i01 < ir1; i01++) {
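This fix matters for quantized types: ne00 counts elements, while type_size is the size of one quantized block, so ne00 * type_size overstates the row size whenever the block size is greater than 1. ggml_row_size(type, ne) divides by the block size. A sketch of the arithmetic, using ggml's actual Q8_0 layout as the example:

    #include <cstddef>
    #include <cstdint>

    // In essence (see ggml_row_size in ggml.c): bytes per row = ne * type_size / block_size.
    size_t row_size_sketch(int64_t ne, size_t type_size, int64_t block_size) {
        return type_size * ne / block_size;
    }

    // Q8_0: blocks of 32 elements, each stored as an fp16 scale + 32 int8 values = 34 bytes.
    // A row of 4096 elements is 4096/32 * 34 = 4352 bytes, not 4096 * 34.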
@@ -17871,6 +17871,8 @@ static void ggml_compute_forward_flash_attn_ext_f16(

 #if GGML_USE_IQK_MULMAT
     if (max_bias <= 0.0f && q->type == GGML_TYPE_F32 && mask && mask->type == GGML_TYPE_F16) {
+        //if (ith == 0) printf("k: %ld x %ld x %ld, q: %ld x %ld x %ld, v: %ld x %ld x %ld mask: %ld x %ld x %ld\n",
+        //        k->ne[0], k->ne[1], k->ne[2], q->ne[0], q->ne[1], q->ne[2], v->ne[0], v->ne[1], v->ne[2], mask->ne[0], mask->ne[1], mask->ne[2]);
         // I keep changing my mind what is the best strategy to split the threads when processing
         // multiple heads. This is my current thinking, the commented out code below was the previous.
         int ntg = nth/simple_gcd(neq2*neq3, nth);
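The split itself: neq2*neq3 is the number of (head, batch) slices, and dividing nth by gcd(neq2*neq3, nth) yields ntg, the number of threads that team up on a single slice. A small worked sketch, assuming simple_gcd is a plain Euclidean GCD as the name suggests:

    #include <cstdio>

    static int gcd_sketch(int a, int b) { return b == 0 ? a : gcd_sketch(b, a % b); }

    int main() {
        // 16 threads, 8 (head, batch) slices: gcd = 8, so ntg = 2 threads per slice.
        // 16 threads, 32 slices: gcd = 16, so ntg = 1 and each thread owns whole slices.
        int nth = 16, slices = 8;
        int ntg = nth / gcd_sketch(slices, nth);
        std::printf("threads per slice: %d\n", ntg); // prints 2
    }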
@@ -17906,8 +17908,8 @@ static void ggml_compute_forward_flash_attn_ext_f16(
             }
             return;
         IQK_Flash_Attn_NotAvailable:;
+        printf("iqk_flash was rejected\n");
     }

 #endif

 const uint32_t n_head = neq2;
ggml/src/iqk/iqk_mul_mat.cpp (83 additions, 11 deletions)

@@ -15016,7 +15016,7 @@ template <int k_step>
 struct BaseHelper {
     BaseHelper(const char * data, int stride) : data(data), block(data), stride(stride) {}

-    inline void set_block(int k1) { block = data + k1*k_step*stride; }
+    //inline void set_block(int k1) { block = data + k1*k_step*stride; }
     inline void reset_block() { block = data; }
     inline void next_block() { block += k_step*stride; }
     inline const char * lblock(int l1) const { return block + l1*stride; }
@@ -16038,9 +16038,9 @@ struct FlashQKV {
     }

     inline void normalize_and_store(const FlashMS<q_step, k_step>& fms, int j, const qkv_cache_t * R, float * qkv) const {
-        //GGML_ASSERT(fms.S[j] > 0);
-        //auto norm = F16::set1(1/fms.S[j]);
-        auto norm = F16::set1(fms.S[j] > 0 ? 1/fms.S[j] : 0.f);
+        GGML_ASSERT(fms.S[j] > 0);
+        auto norm = F16::set1(1/fms.S[j]);
+        //auto norm = F16::set1(fms.S[j] > 0 ? 1/fms.S[j] : 0.f);
         for (int i = 0; i < D/F16::block_size; ++i) {
             auto r = F16::load(R + F16::block_size*i);
             F16::store(qkv + F16::block_size*i, F16::mul(norm, r));
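For context, S[j] is the running softmax denominator of the streaming (flash-attention) formulation: the accumulator R holds unnormalized output rows and the division happens once at the end. For a row that is not fully masked out, S[j] must be positive, which is what the re-enabled assert checks. In the usual notation (a sketch of the standard formulation, not text from the diff):

    O_j = \frac{1}{S_j} \sum_i e^{\,\mathrm{scale}\; q_j \cdot k_i - M_j}\, v_i,
    \qquad S_j = \sum_i e^{\,\mathrm{scale}\; q_j \cdot k_i - M_j},
    \qquad M_j = \max_i\, \mathrm{scale}\; q_j \cdot k_i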
@@ -16076,7 +16076,7 @@ struct FlashQKV {

 template <int D, int q_step, int k_step>
 struct FlashQKfp32 {
-    static_assert(D%F16::block_size == 0 && D <= 256);
+    static_assert(D%F16::block_size == 0 && D <= 576);
     static_assert(k_step%F16::block_size == 0);
     static_assert(q_step <= 4 || q_step%4 == 0);

@@ -16571,8 +16571,8 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q,
 // q_step-1 versions of these functions for us, which I though was too much with q_step = 8.
 template <int Dk, int Dv, int q_step, int k_step>
 struct FlashAttn {
-    static_assert(Dk%F16::block_size == 0 && Dk <= 256);
-    static_assert(Dv%F16::block_size == 0 && Dv <= 256);
+    static_assert(Dk%F16::block_size == 0 && Dk <= 576);
+    static_assert(Dv%F16::block_size == 0 && Dv <= 512);
     static_assert(k_step%F16::block_size == 0);
     static_assert(q_step <= 4 || q_step%4 == 0);

@@ -16665,7 +16665,8 @@ struct HelperBF16 final : public BaseHelper<step> {

 template <int D, int q_step, int k_step>
 struct FlashQKbf16 {
-    static_assert(D%32 == 0 && D <= 256);
+    //static_assert(D%32 == 0 && D <= 256);
+    static_assert(D%32 == 0 && D <= 576);
     static_assert(k_step%32 == 0);
     static_assert(q_step <= 4 || q_step%4 == 0);

@@ -16975,8 +16976,10 @@ struct FlashQKbf16 {

 template <int Dk, int Dv, int q_step, int k_step>
 struct FlashAttnBF16 {
-    static_assert(Dk%32 == 0 && Dk <= 256);
-    static_assert(Dv%32 == 0 && Dv <= 256);
+    //static_assert(Dk%32 == 0 && Dk <= 256);
+    //static_assert(Dv%32 == 0 && Dv <= 256);
+    static_assert(Dk%32 == 0 && Dk <= 576);
+    static_assert(Dv%32 == 0 && Dv <= 512);
     static_assert(k_step%32 == 0);
     static_assert(q_step <= 4 || q_step%4 == 0);

@@ -17216,6 +17219,66 @@ inline bool flash_attn_is_supported(ggml_type type) {
 #endif
     return false;
 }
+
+template <int step_k, typename KHelper, typename VHelper>
+inline void iqk_deepseek_helper(KHelper& kh, VHelper& vh,
+        int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
+        const float * q, const char * mask, float scale, float softcap, float * qkv) {
+    if (nq1 % 8 == 0) {
+        FlashAttn<576, 512, 8, step_k> fa(scale, softcap);
+        fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv);
+    } else {
+        FlashAttn<576, 512, 1, step_k> fa(scale, softcap);
+        fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv);
+    }
+}
+
+template <int step_k>
+inline bool iqk_deepseek_helper(ggml_type type_k,
+        int nq1, int nk1, int stride_q, int stride_k, int stride_v, int stride_m, int stride_qkv,
+        const float * q, const char * k, const char * v, const char * mask,
+        float scale, float softcap, float * qkv) {
+    if (type_k == GGML_TYPE_Q8_0) {
+        HelperQ80<576, step_k> kh((const char *)k, stride_k);
+        HelperQ80<512, step_k> vh((const char *)v, stride_v);
+        iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
+        return true;
+    }
+    if (type_k == GGML_TYPE_Q6_0) {
+        HelperQ60<576, step_k> kh((const char *)k, stride_k);
+        HelperQ60<512, step_k> vh((const char *)v, stride_v);
+        iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
+        return true;
+    }
+    if (type_k == GGML_TYPE_Q8_KV) {
+        HelperQ8KV<576, step_k> kh((const char *)k, stride_k);
+        HelperQ8KV<512, step_k> vh((const char *)v, stride_v);
+        iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
+        return true;
+    }
+    if (type_k == GGML_TYPE_F16) {
+        HelperF16<576, step_k> kh((const char *)k, stride_k);
+        HelperF16<512, step_k> vh((const char *)v, stride_v);
+        iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv);
+        return true;
+    }
+#ifdef __AVX512BF16__
+    if (type_k == GGML_TYPE_BF16) {
+        HelperBF16<576, step_k> kh((const char *)k, stride_k);
+        HelperBF16<512, step_k> vh((const char *)v, stride_v);
+        if (nq1 % 8 == 0) {
+            FlashAttnBF16<576, 512, 8, step_k> fa(scale, softcap);
+            fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
+        } else {
+            FlashAttnBF16<576, 512, 1, step_k> fa(scale, softcap);
+            fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv);
+        }
+        return true;
+    }
+#endif
+    return false;
+}
+
 }
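The hard-coded 576/512 are DeepSeek's MLA head sizes: in the DeepSeek-V2/V3 configuration the K head is the 512-wide latent (kv_lora_rank) concatenated with a 64-wide rotary part (qk_rope_head_dim), while V is the 512-wide latent alone. A sketch of that accounting (the names come from the model config, not from this file):

    // MLA K/V head dimensions behind the FlashAttn<576, 512, ...> instantiations above.
    constexpr int kv_lora_rank     = 512; // latent ("nope") part, shared by K and V
    constexpr int qk_rope_head_dim =  64; // rotary part, present only in K
    constexpr int Dk = kv_lora_rank + qk_rope_head_dim; // 576
    constexpr int Dv = kv_lora_rank;                    // 512
    static_assert(Dk == 576 && Dv == 512);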

@@ -17237,10 +17300,19 @@ bool iqk_flash_attn_noalibi(int int_type_k, // type of k
                             float softcap, // if > 0, a "soft-cap" operation is applied before softmax
                             float * qkv) { // v*softmax(scale*(k*q))

+    if (!mask || nk1%32 != 0) return false; // the implementation assumes mask is not null and nk is a multiple of 32
+
     auto type_k = ggml_type(int_type_k);
     auto type_v = ggml_type(int_type_v);

+    if (Dk == 576 && Dv == 512) {
+        GGML_ASSERT(type_k == type_v);
+        stride_q /= sizeof(float); // q stride as float
+        return iqk_deepseek_helper<32>(type_k, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv,
+                q, (const char *)k, (const char *)v, (const char *)mask, scale, softcap, qkv);
+    }
+
     if (!flash_attn_is_supported(type_k) || !flash_attn_is_supported(type_v)) return false;
-    if (!mask || nk1%32 != 0) return false; // the implementation assumes mask is not null and nk is a multiple of 32
     if (Dk != Dv && Dk != 192 && Dv != 128) return false;
     if (Dv != 64 && Dv != 96 && Dv != 128 && Dv != 256) return false;
     if (Dk != 64 && Dk != 96 && Dk != 128 && Dk != 192 && Dv != 256) return false;
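Two behavioral notes on the dispatch above: the Dk == 576 / Dv == 512 path is taken before the flash_attn_is_supported() checks, so for MLA shapes K and V may be Q8_0, Q6_0, Q8_KV, F16 or, on AVX512-BF16 builds, BF16, with K and V required to share one type; and the mask/nk1 precondition was moved up so it now guards both paths. The "soft-cap" mentioned in the softcap comment is, as far as I can tell, the same tanh capping used for attention logits in mainline llama.cpp (stated as background, not taken from this diff):

    \tilde{s}_{ij} = c \cdot \tanh\!\left(\frac{s_{ij}}{c}\right),
    \qquad s_{ij} = \mathrm{scale}\,(q_j \cdot k_i),
    \qquad c = \mathrm{softcap} > 0,

applied to the logits before the mask and softmax.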