
MLA: allow Q8_0 K-cache for MLA #206


Merged
merged 1 commit into main on Feb 13, 2025

Conversation

ikawrakow
Owner

After PR #205 we have two KV caches left when using MLA:

  • kv_l - contiguous, not transposed
  • kvt_l - a transposed version of kv_l

kv_l can be quantized, and this PR adds the necessary changes.
kvt_l, being a transposed version of kv_l, cannot be quantized. It can be eliminated by setting MLA_USE_TRANSPOSED_CACHE to 0 in llama.cpp, but then kv_l cannot be quantized either, because building a contiguous transposed tensor out of a quantized tensor, as needed during inference, is not supported at this point.
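
To make the new option concrete, here is a minimal sketch (my illustration, not code from this PR) of selecting a quantized K-cache through the public llama.cpp API. It assumes the fork keeps the standard `type_k`/`type_v` fields in `llama_context_params` and that, with MLA enabled, `type_k` governs kv_l (the cache this PR allows to be quantized); whether `type_v` maps to kvt_l here is an assumption, and the model path is a placeholder.

```cpp
#include "llama.h"

int main() {
    llama_backend_init();

    // Load the model (placeholder path).
    llama_model_params mparams = llama_model_default_params();
    llama_model * model = llama_load_model_from_file("deepseek-lite-iq4_xs.gguf", mparams);
    if (!model) return 1;

    llama_context_params cparams = llama_context_default_params();
    cparams.n_ctx  = 16384;
    cparams.type_k = GGML_TYPE_Q8_0; // quantized kv_l cache (enabled by this PR)
    cparams.type_v = GGML_TYPE_BF16; // assumed to map to kvt_l under MLA; bf16 only
                                     // pays off on CPUs with native bf16 support

    llama_context * ctx = llama_new_context_with_model(model, cparams);
    if (!ctx) return 1;

    // ... evaluate prompts / generate tokens as usual ...

    llama_free(ctx);
    llama_free_model(model);
    llama_backend_free();
    return 0;
}
```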

Apart from reducing the required KV cache memory, a quantized kv_l cache can also slightly improve TG performance after a long prompt. Here is a comparison between the main branch and this PR for tg64@ppN at different prompt lengths N. The model is IQ4_XS-quantized DeepSeek-Lite. The main-branch results use fp16 for both the kv_l and kvt_l caches; the PR results use Q8_0 for kv_l and bf16 for kvt_l (bf16 only makes sense on a CPU with native support, such as the Ryzen-7950X used to run the benchmark). A sketch of a reproduction command follows the table.

| model | test | t/s (main) | t/s (PR) | Speedup |
| --- | --- | --- | --- | --- |
| deepseek2 16B IQ4_XS | tg64@pp128 | 33.80 ± 0.00 | 33.67 ± 0.01 | 0.996 |
| deepseek2 16B IQ4_XS | tg64@pp256 | 32.76 ± 0.01 | 33.55 ± 0.01 | 1.024 |
| deepseek2 16B IQ4_XS | tg64@pp512 | 32.68 ± 0.05 | 32.31 ± 0.00 | 0.989 |
| deepseek2 16B IQ4_XS | tg64@pp1024 | 32.02 ± 0.00 | 32.07 ± 0.00 | 1.002 |
| deepseek2 16B IQ4_XS | tg64@pp2048 | 30.31 ± 0.03 | 30.93 ± 0.00 | 1.020 |
| deepseek2 16B IQ4_XS | tg64@pp4096 | 27.54 ± 0.10 | 28.79 ± 0.07 | 1.045 |
| deepseek2 16B IQ4_XS | tg64@pp8192 | 23.12 ± 0.01 | 25.21 ± 0.02 | 1.090 |
| deepseek2 16B IQ4_XS | tg64@pp16384 | 18.74 ± 0.09 | 19.81 ± 0.05 | 1.057 |
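
For reference, here is the kind of llama-bench invocation that could produce the rows above: a sketch, assuming this fork's llama-bench accepts the mainline `-gp <pp,tg>` (generate after prompt) and `-ctk`/`-ctv` (cache type) flags, and that those types map to the kv_l/kvt_l caches under MLA; the model path is a placeholder.

```sh
# Reproduction sketch; the flag-to-cache mapping under MLA is an assumption.
# Each -gp pp,64 pair should yield one tg64@ppN row.
./llama-bench -m deepseek-lite-iq4_xs.gguf -ctk q8_0 -ctv bf16 \
    -gp 128,64 -gp 256,64 -gp 512,64 -gp 1024,64 \
    -gp 2048,64 -gp 4096,64 -gp 8192,64 -gp 16384,64
```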

ikawrakow merged commit 8e94b29 into main on Feb 13, 2025