FlashMLA-3: the best of both worlds (CPU only) #273
Conversation
As we don't have a way to tell if a repacked quant has been modified, I had to remove the modification at the expense of a slight decrease in performance. This affects q8_0_r8, q8_KV_r8, q8_k_r8 on Zen4, and q4_0_r8 on ARM.
Clever idea to combine the best of both worlds, PP with `mla=2` and TG with `mla=1`.
So reading closely, it sounds like this is CPU only? fwiw, before thinking about it too much, I just compiled and tried to run it with the CUDA backend. I hope to kick the tires on this with the Intel 6980P tomorrow.
Yes, it is CPU only. Based on the above graphs, this is what I would recommend for CPU-only inference.
Strange. It does run correctly on my end. The unsupported FA variant (head sizes 576 and 512) gets run on the CPU. I tried and was surprised to see that performance for DeepSeek-Lite is only marginally lower compared to all attention computed on the CPU:
In comparison, the same command but using ...
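For illustration only (not the actual command used above), a comparison of this kind could be scripted with `llama-bench`; the model path is hypothetical, and the ik_llama.cpp-specific `-mla` flag is an assumption based on the options discussed in this thread:

```bash
# Hypothetical sketch: CUDA build with all layers offloaded.
# The FA variants with head sizes 576/512 are expected to fall back to the CPU.
./bin/llama-bench -m deepseek-lite.gguf -ngl 99 -mla 3 -fa 1 -p 512 -n 128

# Same model with everything kept on the CPU, for comparison.
./bin/llama-bench -m deepseek-lite.gguf -ngl 0 -mla 3 -fa 1 -p 512 -n 128
```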
Would it be possible to use FA for PP and no FA for TG? That would be the best of both worlds for my AVX2 system. I did some testing to get a baseline to later compare against the HugePage mmap version, and PP is the best I've seen for IQ4_K_R4 when FA is turned on. (IQ4_K seems like it would still perform better, given I had gotten 11.5 t/s before MLA was implemented, but I don't have that quant anymore, and I'm still not sure why it performed better than IQ4_K_R4, especially now that I've seen others use the repacked quants without this issue.)

Results with FA off:

Results with FA on (the first PP result can be ignored, as there was still some model loading going on; I saw disk activity):
I think it is the number of threads you are using that leads to the lower TG performance. The efficient path is not taken when the number of threads is not a power of 2. Can you try TG with 32 threads to confirm before I try to make changes?
I had already run some tests with 16, 24, 32, and 48 threads with FA on; results below, but this is without dropping the caches like I normally do before changing thread counts (which is why I think the TG is 2.61 instead of 2.75).
Sorry, I won't be available to run more tests till tomorrow.
The difference between dropping and not dropping caches is almost the same as the difference between FA off and FA on? Hope we are not chasing our tail here. But when you come around to testing again, I recommend trying ...
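For reference, the standard way to drop the Linux page cache between runs (so a freshly changed thread count starts from a cold mmap) is:

```bash
# Flush dirty pages, then drop the page cache, dentries and inodes.
# Requires root; the next run will re-read the model from disk.
sync && echo 3 | sudo tee /proc/sys/vm/drop_caches
```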
Here are some results for all combinations of ...
Looking at your results with FA off, MLA-3 is similar to the lower TG of MLA-2 and not the faster MLA-1; with FA, MLA-3 is similar to the faster MLA-1. Is that what is expected?
I know it wasn't a good test to show TG performance at 32 threads; that test was done to check performance at 16 threads and to get more insight into the behavior when not dropping the caches before changing thread count. I've known it's bad, but I haven't done enough testing to understand how much the severity of the impact varies. The model takes 20-30 minutes to load depending on thread count (with higher thread counts taking longer). Interestingly, PP performance seems to be unaffected by not dropping the cache, as the values at 32 and 48 threads match the results with dropping the cache.
I ran more tests (new tests run on commit 3d6e25c) and put the results (including the 48-thread results from above) in a table for easy viewing.
** I calculated these values excluding the run where disk activity caused low performance.

No results for q8_KV with FA on, as it crashed hitting this assert.

As you can see, the best TG result of those tested is still 48 threads with FA off and f16 type_k, and the best PP is also at 48 threads but with FA on and f16 type_k. Going to q8_0 or q8_KV did help slightly when tested with 32 threads. PP performance at 32 threads is in line with my testing without dropping the cache, where it performed far worse than all other tested thread counts; I'm not really sure why that is, so even if 32 threads were ideal for TG, it would come at a steep penalty for PP.
I know tg128 is not the best test; I prefer to do longer tests and also to test deeper into the KV cache, but I was just planning to grab a baseline to see if the HugePage mmap changes can get anywhere close to the +50% TG uplift orca-zhang saw on his machine. Also, in #240 you reported that FA degraded MLA-1 performance on AVX2, which is what made me test with FA on and off (although I was surprised to see a difference with just tg128, given your results both here and there). I forgot that you improved that with #243, but as shown above the situation I see is different (could it be because of the size of the model?).
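A minimal sketch of how a grid like the table above could be collected, assuming a hypothetical model path and assuming this build's `llama-bench` accepts the ik_llama.cpp-specific `-mla`/`-fmoe` options alongside the standard `-t`, `-fa`, and `-ctk` flags:

```bash
#!/usr/bin/env bash
# Sketch only: sweep thread count and K-cache type, with FA off and on.
MODEL=model-iq4_k_r4.gguf   # hypothetical path
for t in 16 24 32 48; do
  for ctk in f16 q8_0; do
    for fa in 0 1; do
      sync && echo 3 | sudo tee /proc/sys/vm/drop_caches > /dev/null
      ./bin/llama-bench -m "$MODEL" -t "$t" -fa "$fa" -ctk "$ctk" \
                        -mla 3 -fmoe 1 -p 512 -n 128
    done
  done
done
```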
Yes. With FA off, for TG MLA-3 is identical to MLA-2. With FA on, it is identical to MLA-1.
Ran MLA-3 with FA through a much longer test via sweep-bench; will do the other 5 combinations as well.
The results are not ideal because of an issue where TG performance often drops lower, but this is something I've experienced many times before with llama-server as well, where I would work around it by just cancelling generation and sending requests until it wouldn't hit the issue. This bug seems to be caused by the work bouncing around threads, resulting in lower CPU usage; I think I saw that when watching btop while it happened, but I may be wrong.
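A hedged sketch of running the six MLA x FA combinations through `llama-sweep-bench`; the exact flag spellings (`-mla`, `-fmoe`, `-fa` as a toggle) and the model path are assumptions, not confirmed syntax:

```bash
#!/usr/bin/env bash
# Sketch only: all six MLA/FA combinations at 48 threads with fmoe enabled.
MODEL=model-iq4_k_r4.gguf   # hypothetical path
for mla in 1 2 3; do
  ./bin/llama-sweep-bench -m "$MODEL" -c 16384 -t 48 -mla "$mla" -fmoe        # FA off
  ./bin/llama-sweep-bench -m "$MODEL" -c 16384 -t 48 -mla "$mla" -fmoe -fa    # FA on
done
```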
Here are all 6 configurations (all at 48 threads with fmoe turned on) graphed. The MLA-3 FA-on results only go up to 13312, while all other results go up to 15872. The MLA-3 FA-on configuration (excluding the strange bug) does seem like the best of both worlds even before #277, as it matches the strongest-performing configuration in both PP and TG. Raw results:
MLA-1 FA off
MLA-2 FA on
MLA-2 FA off
MLA-3 FA on (only tested to 13312)
MLA-3 FA off
For DeepSeek models, `mla=1` has very good TG but low PP performance. `mla=2` has better PP performance, but TG performance rapidly decreases with the number of tokens in the KV cache. `mla=0` (i.e., standard attention) has the best PP performance, but TG is even lower than `mla=2`. In addition, standard attention requires a much larger KV cache than `mla=1,2`.

Here are two graphs comparing PP and TG performance of `mla=0,1,2` for DeepSeek-Lite. In all cases FA is enabled, the KV cache is quantized with `Q8_0`, the model weights are quantized with `IQ4_NL`, and the calculations are run on a Ryzen-7950X CPU. The second graph is TG speed as a function of the number of tokens in the KV cache (obtained using `llama-bench -gp Np,64`). Note the logarithmic x-axis for both graphs.

Since `mla=1` and `mla=2` use the same KV cache (actually, just the K-cache, as `V` gets computed from the K-cache), we can take the best parts of `mla=1` and `mla=2` and create `mla=3`, where prompt processing is done with the `mla=2` approach, while TG is performed with `mla=1`.

Why do we need yet another option? Simply because the CUDA backend does not support `mla=1`, and the `ggml` back-end is very opinionated about where operations should run, with its opinions often being difficult to predict. Hence, when building the graph with more than one compute backend available, one cannot easily predict if the operation(s) will run on the CPU or on the other compute backend, so it is easier to just have another option that the user can turn on via command-line arguments.

Coming back to the above graphs, `mla=3` PP performance is given by the blue curve in the first graph, and TG performance by the red curve in the second graph.
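To make the measurement concrete, here is a hedged sketch of collecting TG-vs-cache-depth points with `llama-bench -gp Np,64` as described above; the model path is hypothetical and the `-mla` flag spelling is assumed:

```bash
#!/usr/bin/env bash
# Sketch only: measure TG speed for 64 tokens after Np prompt tokens,
# for each mla mode, as in the second graph above.
MODEL=deepseek-lite-iq4_nl.gguf   # hypothetical path
for mla in 0 1 2 3; do
  for np in 512 2048 8192; do
    ./bin/llama-bench -m "$MODEL" -mla "$mla" -fa 1 -ctk q8_0 -gp "$np,64"
  done
done
```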