Add optional MLA #188

Merged · 6 commits merged into main on Feb 9, 2025

Conversation

@ikawrakow (Owner) commented Feb 6, 2025

This PR is derived from #180. The difference from #180 is that MLA is made optional: it is off by default and can be turned on with the added -mla (or --use-mla) command-line option.

Rationale: MLA improves TG speed, especially with long contexts, but it also makes prompt processing significantly slower. Hence MLA is made optional, since whether it helps or hurts depends on the use case.

Being able to select or deselect MLA at run time is possible because #180 leaves the original wkv_b tensor and its decomposition into wk_b and wv_b in the model. This is somewhat wasteful, but these tensors are not very large and now come in handy for easily selecting between the two attention implementations.

In addition:

  • It is now possible to use a model converted without this PR, so that the wk_b and wv_b tensors are missing. In this case MLA will be disabled even if requested on the command line (see the sketch after this list).
  • Eliminated some unnecessary copies (ggml_cont). This repo has supported non-contiguous RoPE for a while, and non-contiguous RMS norm on CUDA was added in cuda: non-contiguous rms norm #190 (the CPU has always supported non-contiguous RMS norm).
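
A minimal sketch of the resulting run-time decision, using hypothetical names (the real code in this PR may differ): MLA stays off unless it was requested on the command line and every layer still carries the wk_b/wv_b tensors.

#include <cstdio>
#include <vector>

// Hypothetical stand-ins for the real model structures (names assumed).
struct mla_layer { void * wk_b = nullptr; void * wv_b = nullptr; };
struct mla_model { std::vector<mla_layer> layers; };

// Enable MLA only if it was requested (-mla / --use-mla) and the converted
// model still contains the decomposed wk_b / wv_b tensors in every layer.
static bool can_use_mla(const mla_model & model, bool requested) {
    if (!requested) {
        return false;                                   // MLA is off by default
    }
    for (const auto & layer : model.layers) {
        if (!layer.wk_b || !layer.wv_b) {
            std::fprintf(stderr, "missing wk_b/wv_b tensors -> MLA disabled\n");
            return false;                               // model converted without this PR
        }
    }
    return true;
}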

@saood06 (Collaborator) commented Feb 8, 2025

There were some other changes to gguf-py/gguf/tensor_mapping.py that are in saood06#1 that I missed porting over earlier.

The next thing I was going to do was stop the old KV cache from being allocated. I hadn't gotten around to it, as I had a workaround from the mmap KV cache feature, but it should be a relatively simple fix; I'll look into it when I have more time.

@saood06 (Collaborator) commented Feb 8, 2025

@ikawrakow I made #195 to merge into this with the things mentioned.

saood06 and others added 3 commits February 9, 2025 09:36
* Avoid allocating MHA KV cache when MLA is turned on

* Added missing gguf-py file

* Added final optimizations

Co-authored-by: Stanisław Szymczyk <[email protected]>

* Make sure we do have wk_b and wv_b before enabling MLA

---------

Co-authored-by: Stanisław Szymczyk <[email protected]>
Co-authored-by: Iwan Kawrakow <[email protected]>
The KV cache types were hard-coded at f16.
On my Ryzen-7950X with native bf16 support I get a fairly
significant PP performance boost with bf16 KV cache:
PP-4096 = 320 t/s, up from 292 t/s with fp16 KV cache.
It gives a ~10% PP performance boost for DeepSeek-Lite with 32 threads
(with or without MLA).
Before this commit, when nth > nhead, heads were processed
sequentially with all nth threads participating in each
matrix multiplication. Now we find the gcd of nhead and
nth and split the threads into groups of nth/gcd threads,
each group processing nhead/gcd heads.
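
A hedged illustration of that grouping (a self-contained toy example, not the repo's code): with g = gcd(nhead, nth) there are g groups of nth/g threads each, and each group is responsible for nhead/g heads.

#include <cstdio>
#include <numeric>   // std::gcd (C++17)

int main() {
    const int nhead = 16;                       // attention heads (example values)
    const int nth   = 24;                       // threads

    const int g                 = std::gcd(nhead, nth);  // number of thread groups
    const int threads_per_group = nth   / g;
    const int heads_per_group   = nhead / g;

    for (int ith = 0; ith < nth; ++ith) {
        const int group      = ith / threads_per_group;  // group this thread joins
        const int first_head = group * heads_per_group;  // heads [first_head, first_head + heads_per_group)
        std::printf("thread %2d -> group %d, heads [%d, %d)\n",
                    ith, group, first_head, first_head + heads_per_group);
    }
    return 0;
}
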
@ikawrakow (Owner, Author) commented:
I think we can merge this now.

@saood06 mentioned this pull request on Feb 9, 2025
@saood06 (Collaborator) left a review:

LGTM. Good catch on applying cache quantization; it was something I had missed. BF16 makes sense when it is faster, but I never bothered as I'm assuming it would come with a large quality loss.

Once this is merged I'll make PRs for the warmup MoE fix and then the mmap KV allocator.

Testing was a bit of a pain without the warmup MoE fix, as loading the experts takes much longer (and it is already quite long, as this server has only an HDD, no SSD) and requires many runs instead of just one warmup. PP seems slightly lower compared to my local testing branch, but that might just be variance, or come from the mmap KV allocator that I have yet to make a PR for.

@ikawrakow (Owner, Author) commented:

> BF16 makes sense when it is faster, but I never bothered as I'm assuming it would come with a large quality loss.

Why? Most modern models are trained in bf16, so bf16 will be better than fp16. But if the CPU does not have native bf16 support, it will be somewhat slower.

> Once this is merged I'll make PRs for the warmup MoE fix and then the mmap KV allocator.

Sounds good.

@ikawrakow merged commit c12f73b into main on Feb 9, 2025
@saood06 (Collaborator) commented Feb 9, 2025

> BF16 makes sense when it is faster, but I never bothered as I'm assuming it would come with a large quality loss.
>
> Why? Most modern models are trained in bf16, so bf16 will be better than fp16. But if the CPU does not have native bf16 support, it will be somewhat slower.

I misspoke: I meant I never bothered quantizing the MLA version down to Q4 or Q6 as I did with the non-MLA solution. I know most models are bf16 native (DeepSeek was FP8 native, which I had to upscale to BF16 before making the GGUF), and I would use BF16 if I had a modern processor with support for it.

The old solution was MHA, which quantizes down very well and is large enough to warrant it. Heavy GQA does not; MLA is sized like heavy GQA and is small enough that I'm fine leaving it at F16 rather than something smaller, and not BF16, as my CPU is old and doesn't handle BF16 well.

Comment on lines +3195 to +3207
// DeepSeek MLA
const uint32_t n_embd_head_qk_rope = hparams.n_rot;
const uint32_t kv_lora_rank = hparams.n_lora_kv;
LLAMA_LOG_INFO("%s: layer %d: n_embd_head_qk_rope = %d, kv_lora_rank = %d\n", __func__, i, n_embd_head_qk_rope, kv_lora_rank);
ggml_tensor * kr = ggml_new_tensor_1d(ctx, cache.type_kr, n_embd_head_qk_rope*kv_size);
ggml_tensor * kv = ggml_new_tensor_1d(ctx, cache.type_kv, kv_lora_rank*kv_size);
ggml_tensor * kvt = ggml_new_tensor_1d(ctx, cache.type_kv, kv_lora_rank*kv_size);
ggml_format_name(kr, "cache_kr_l%d", i);
ggml_format_name(kv, "cache_kv_l%d", i);
ggml_format_name(kvt, "cache_kvt_l%d", i);
cache.kr_l.push_back(kr);
cache.kv_l.push_back(kv);
cache.kvt_l.push_back(kvt);
@saood06 (Collaborator) commented:
Sorry I missed this, but I think this should be in the if block above, as it is not needed for non-MLA models.
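
A sketch of the suggested guard (the flag name cparams.mla_attn is an assumption; the allocation calls are the ones from the diff above):

// Hypothetical: create the MLA cache tensors only when MLA is enabled,
// so non-MLA models do not allocate them.
if (cparams.mla_attn) {
    const uint32_t n_embd_head_qk_rope = hparams.n_rot;
    const uint32_t kv_lora_rank = hparams.n_lora_kv;
    ggml_tensor * kr  = ggml_new_tensor_1d(ctx, cache.type_kr, n_embd_head_qk_rope*kv_size);
    ggml_tensor * kv  = ggml_new_tensor_1d(ctx, cache.type_kv, kv_lora_rank*kv_size);
    ggml_tensor * kvt = ggml_new_tensor_1d(ctx, cache.type_kv, kv_lora_rank*kv_size);
    ggml_format_name(kr,  "cache_kr_l%d",  i);
    ggml_format_name(kv,  "cache_kv_l%d",  i);
    ggml_format_name(kvt, "cache_kvt_l%d", i);
    cache.kr_l.push_back(kr);
    cache.kv_l.push_back(kv);
    cache.kvt_l.push_back(kvt);
}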

Comment on lines +18058 to +18075
{
size_t memory_size_kr = 0;
size_t memory_size_kv = 0;

for (auto & kr : ctx->kv_self.kr_l) {
memory_size_kr += ggml_nbytes(kr);
}

for (auto & kv : ctx->kv_self.kv_l) {
memory_size_kv += ggml_nbytes(kv);
}

LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K^R (%s): %7.2f MiB, c^KV (%s): %7.2f MiB\n", __func__,
(float)(memory_size_kr + memory_size_kv) / (1024.0f * 1024.0f),
ggml_type_name(type_k), (float)memory_size_kr / (1024.0f * 1024.0f),
ggml_type_name(type_k), (float)memory_size_kv / (1024.0f * 1024.0f));
}

@saood06 (Collaborator) commented:
With the above change only one of these caches will be allocated, so only that one should be displayed as the KV self size.
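
A sketch of the corresponding logging change (again assuming a hypothetical mla_attn flag): report the K^R/c^KV sizes only when MLA is active, and keep the existing MHA K/V report otherwise.

// Hypothetical: only the cache that was actually allocated is reported.
if (ctx->cparams.mla_attn) {
    size_t memory_size_kr = 0;
    size_t memory_size_kv = 0;
    for (auto & kr : ctx->kv_self.kr_l) memory_size_kr += ggml_nbytes(kr);
    for (auto & kv : ctx->kv_self.kv_l) memory_size_kv += ggml_nbytes(kv);
    LLAMA_LOG_INFO("%s: KV self size  = %7.2f MiB, K^R (%s): %7.2f MiB, c^KV (%s): %7.2f MiB\n", __func__,
            (float)(memory_size_kr + memory_size_kv) / (1024.0f * 1024.0f),
            ggml_type_name(type_k), (float)memory_size_kr / (1024.0f * 1024.0f),
            ggml_type_name(type_k), (float)memory_size_kv / (1024.0f * 1024.0f));
} else {
    // fall through to the existing MHA K/V cache size report
}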
