Trellis quants with CPU inference #441
Conversation
Using 12 bits per 8 weights I get a better rmse than iq2_xxs. I still need to see how quantizing the group-of-8 scales will affect accuracy. By AVX2 SIMDifying the search for the best code, LLaMA-3.1-8B gets quantized in 130 seconds on the Ryzen-7950X CPU - sluggish but still acceptable.
rmse increases by just 3%, so this is beating iq2_xxs in terms of rmse at the same 2.0625 bpw.
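For readers following along, here is a minimal scalar sketch of the kind of brute-force code search described a couple of comments up. The decoder below is a generic LCG stand-in, and the names gen8/best_code_for_group and the constants are illustrative, not from this PR; the actual implementation vectorizes the inner loop with AVX2.

```cpp
#include <cstdint>
#include <limits>

// Stand-in for the trellis decoder: expands a 12-bit code into 8 values via a
// generic LCG (placeholder constants, not the actual 3INST parameters).
static void gen8(uint32_t code, float out[8]) {
    uint32_t state = code;
    for (int i = 0; i < 8; ++i) {
        state = state*1664525u + 1013904223u;
        out[i] = (int32_t)state * (1.0f/2147483648.0f);  // map to roughly [-1, 1)
    }
}

// Exhaustive search: try all 2^12 codes for one group of 8 weights and keep the
// code with the smallest squared error at the given scale.
static uint32_t best_code_for_group(const float x[8], float scale) {
    uint32_t best = 0;
    float best_err = std::numeric_limits<float>::max();
    float vals[8];
    for (uint32_t code = 0; code < (1u << 12); ++code) {
        gen8(code, vals);
        float err = 0;
        for (int i = 0; i < 8; ++i) {
            float diff = x[i] - scale*vals[i];
            err += diff*diff;
        }
        if (err < best_err) { best_err = err; best = code; }
    }
    return best;
}
```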
I now see that I was comparing apples to oranges: iq2_xxs was using a weight of sigma^2/4 + x^2, while the Trellis approach wasn't (weight = 1). Once I use the same weight, iq2_kt is actually slightly worse than iq2_xxs in terms of rmse, so does not look promising at this point. Also, once each group of 8 Trellis values no longer has a constant sum(q^2) that we can precompute, quantization becomes significantly slower (476 seconds for LLaMA-3.1-8B).
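A small sketch of the importance weighting being compared here, assuming sigma^2 is the mean squared weight over the block (the exact normalization in the quantization code may differ); with weight = 1 the loop reduces to plain squared error, which is the apples-to-oranges comparison mentioned above.

```cpp
// Weighted squared error for one block: w[i] = sigma^2/4 + x[i]^2, with sigma^2
// taken as the mean of x^2 over the block (illustrative choice of normalization).
float weighted_err(const float * x, const float * q, float scale, int n) {
    float sigma2 = 0;
    for (int i = 0; i < n; ++i) sigma2 += x[i]*x[i];
    sigma2 /= n;
    float err = 0;
    for (int i = 0; i < n; ++i) {
        float w    = 0.25f*sigma2 + x[i]*x[i];
        float diff = x[i] - scale*q[i];
        err += w*diff*diff;
    }
    return err;
}
```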
so we can run perplexity calcs. As already indicated by rmse, the 2-bit trellis approach is quite a bit worse than iq2_xxs.
With blocks of 32 and 16 bits per group of 8, the brute force search becomes prohibitive in terms of CPU time (30+ minutes for 8B LLaMA, even after SIMDifying with AVX2). The trick is to group the points into clusters, find the nearest cluster, and only search within that cluster.
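A rough sketch of that clustering idea (the data structures and names are mine, not the PR's): the 2^16 candidate codes are bucketed by the 8-point vectors they decode to, the nearest centroid is found first, and the exhaustive search runs only over that bucket.

```cpp
#include <array>
#include <cstdint>
#include <limits>
#include <vector>

// One cluster: its centroid plus the candidate codes assigned to it, each stored
// together with its decoded 8-point vector.
struct Cluster {
    std::array<float, 8> centroid;
    std::vector<std::pair<uint32_t, std::array<float, 8>>> members;  // (code, decoded values)
};

// Find the nearest centroid first, then run the expensive search only over the
// codes assigned to that cluster instead of all candidates.
static uint32_t clustered_search(const std::vector<Cluster> & clusters, const float x[8]) {
    auto dist2 = [&](const std::array<float, 8> & p) {
        float d = 0;
        for (int i = 0; i < 8; ++i) { float t = x[i] - p[i]; d += t*t; }
        return d;
    };
    // 1) cheap pass: pick the nearest cluster centroid
    int best_c = 0;
    float best_d = std::numeric_limits<float>::max();
    for (int c = 0; c < (int)clusters.size(); ++c) {
        float d = dist2(clusters[c].centroid);
        if (d < best_d) { best_d = d; best_c = c; }
    }
    // 2) exhaustive pass, restricted to that cluster's members
    uint32_t best_code = 0;
    float best_err = std::numeric_limits<float>::max();
    for (const auto & m : clusters[best_c].members) {
        float err = dist2(m.second);
        if (err < best_err) { best_err = err; best_code = m.first; }
    }
    return best_code;
}
```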
Using blocks of 32 and 16 bits per group of 8 weights, it beats iq2_xxs in terms of PPL by a significant margin. It is 0.0625 bpw larger, but even if we go to 15 bits per group of 8 (so 0.0625 bpw less than iq2_xxs), PPL is still lower.
Re-quantize after determining block scales (at the expense of much longer quantization time).
Implemented as DMMV. Very slow - just 81 t/s for LLaMA-3.1-8B. Then again, Q2_K_S forced to use DMMV only gets 112 t/s vs 145 t/s via MMVQ. My memory is that when the DMMV kernels were properly maintained/used, DMMV was about on par with MMVQ for k-quants on my GPU.
We arrive at 112 t/s.
We arrive at 139 t/s (no FA) and 149 t/s (FA). My RTX-4080 is ~20% slower than the RTX-6000 quoted in the QTIP repository, so with FA (which I'm sure they also used) we would be at ~180 t/s on their GPU, almost matching their performance.
We arrive at 146 t/s (no FA), and 158 t/s (FA). This is measured for LLaMA-3.1-8B with output.weight left as f16.
3.125 bpw. So far does not look good on the PPL vs bpw plot.
PPL(LLaMA-3.1-8B-Instruct, 8192) is now 6.8322, which is starting to be competitive/slightly better than other quants.
PPL(LLaMA-3.1-8B-Instruct, 8192) is now 6.7892
PPL(LLaMA-3.1-8B-Instruct, 8192) is now 6.7689 after shrinking by 0.015 bpw by using iq4_k instead of q5_k for attn_v.
Nearly 60% improvement in quantization speed by having the points belonging to a cluster copied to contiguous memory during initialization, and then accessed sequentially while searching for the closest point. LLaMA-3.1-8B now gets quantized in ~150 seconds on the Ryzen-5975WX.
Same trick as last commit applied to iq2_kt. Here we get an even larger speedup: quantization time on the Ryzen-5975WX for LLaMA-3.1-8B drops to 195 seconds from 375 seconds!
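A sketch of the layout change these two comments describe, under the same illustrative naming as above: the decoded points of each cluster are packed back to back into one flat buffer at initialization, so the inner search loop becomes a purely sequential, cache-friendly scan.

```cpp
#include <cstdint>
#include <vector>

// Flattened clusters: all member points stored contiguously, with per-cluster
// offsets into the shared buffers (structure names are illustrative).
struct FlatClusters {
    std::vector<float>    points;  // member points of all clusters, back to back (8 floats each)
    std::vector<uint32_t> codes;   // code id of each stored point, in the same order
    std::vector<int>      offset;  // offset[c]..offset[c+1] = point range of cluster c
};

// Search within cluster c: a sequential scan over its contiguous decoded points.
static uint32_t search_in_cluster(const FlatClusters & fc, int c, const float x[8]) {
    uint32_t best_code = 0;
    float best_err = 3.4e38f;
    for (int j = fc.offset[c]; j < fc.offset[c+1]; ++j) {
        const float * p = fc.points.data() + 8*j;   // sequential access
        float err = 0;
        for (int i = 0; i < 8; ++i) { float t = x[i] - p[i]; err += t*t; }
        if (err < best_err) { best_err = err; best_code = fc.codes[j]; }
    }
    return best_code;
}
```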
Is this in debug mode? I'm getting 10.4 t/s for
I'm compiling with Here's how I'm testing:
Should I be using llama-bench or some other tool?
I also tried
We get PP and TG performance as a function of the number of tokens in the KV cache
Ok, well it's great to know the CPU inference performance is not totally unusable and that it's probably just my setup! I will try to figure this out on my own. Might email you some more questions to not pollute this PR discussion. Thanks also for the pointer on benchmarking.
I purged my build directory + recompiled and performance is a lot better, and I no longer see the weird Now F16 gets almost 4x faster at 4.59 generation t/s, and IQ2_KT now beats F16 at 4.83 generation t/s for me.
I did speed up
Overall it looks good to me, so we can think about merging. But there is also PR #435, where I have completely refactored
Terrific, this gets my test machine to 5.59 t/s. I saw the LCG ops in next8 taking up lots of time but wasn't sure what to do about it, this is a cool trick - I assume having the constants as locals keeps them in registers or otherwise ensures they remain hot in cache?

Re: #435 - it looks not too difficult to me to reconcile my new kernels with the refactor. If you're done with your refactor already, you could merge your PR and then I can fix the resulting conflicts on this PR - maybe that's the cleanest way to do this? Since this branch is already conflicting with a file on main anyway. Otherwise happy to merge this first, then work on your branch.
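To illustrate the register-hoisting idea being discussed (with placeholder LCG constants, not the actual 3INST parameters, and a made-up decode_row shape): loading the multiplier, increment and scaling factor into local __m256i/__m256 variables outside the hot loop encourages the compiler to keep them in ymm registers instead of re-materializing them on every next8-style step.

```cpp
#include <immintrin.h>
#include <cstdint>

// Expand one code per group into 8 floats, AVX2 style. The LCG constants and the
// mapping to floats are placeholders for the real trellis generator.
void decode_row(const uint32_t * codes, float * dst, int ngroups) {
    // constants hoisted into locals so they stay in registers across iterations
    const __m256i mul   = _mm256_set1_epi32(1664525);      // placeholder multiplier
    const __m256i add   = _mm256_set1_epi32(1013904223);   // placeholder increment
    const __m256i lane  = _mm256_setr_epi32(1, 2, 3, 4, 5, 6, 7, 8);
    const __m256  scale = _mm256_set1_ps(1.0f/2147483648.0f);
    for (int g = 0; g < ngroups; ++g) {
        // derive 8 parallel LCG states from one code and advance them once
        __m256i state = _mm256_mullo_epi32(_mm256_set1_epi32((int)codes[g]), lane);
        state = _mm256_add_epi32(_mm256_mullo_epi32(state, mul), add);
        // map the 32-bit states to floats in roughly [-1, 1)
        _mm256_storeu_ps(dst + 8*g, _mm256_mul_ps(_mm256_cvtepi32_ps(state), scale));
    }
}
```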
As requested a while ago, this takes #113 and adds CPU implementations of the quantized matmuls (via iqk_mul_mat) for inference. AVX2 and F16C support are required.
As predicted, the CPU ops are very slow. For Llama-3.1-8B-Instruct, I get
4.83 t/s (originally 0.3 t/s) with IQ2_KT compared to 4.59 t/s (originally >1.0 t/s) with F16 on AMD EPYC 7R32 (32 cores). Note I am not a SIMD expert and have only spent moderate time on optimizations (e.g. basic use of AVX2/F16C, flattening of the trellis iterations), so it may be possible to speed things up. I also have not added implementations for HAVE_FANCY_SIMD. Additionally, there are only mulmats for F32 activations, as that is what the 3INST algorithm returns (as pointed out in the original PR description).

I am not sure of the PR practices - if you'd like me to merge into #113 rather than the main branch, happy to change. I also tried to clean up some of the comments / dead code in the WIP branch, but can revert those changes as well.
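For context on what "mulmats for F32 activations" means structurally, here is a much-simplified sketch (placeholder decoder and layout, not the actual iqk_mul_mat kernels): each block of 32 quantized weights is expanded group by group with the trellis generator, scaled by its block scale, and accumulated against F32 activations.

```cpp
#include <cstdint>

// Placeholder trellis decoder: expands one code into 8 floats via a generic LCG
// (not the real 3INST constants).
static void decode_group8(uint32_t code, float out[8]) {
    uint32_t state = code;
    for (int i = 0; i < 8; ++i) {
        state = state*1664525u + 1013904223u;
        out[i] = (int32_t)state * (1.0f/2147483648.0f);  // roughly [-1, 1)
    }
}

// One quantized row dotted with an F32 activation vector: blocks of 32 weights,
// 4 groups of 8 per block, one float scale per block (layout is illustrative).
float vec_dot_trellis_row(const uint32_t * codes, const float * block_scales,
                          const float * activations, int n /* multiple of 32 */) {
    float sum = 0, vals[8];
    for (int ib = 0; ib < n/32; ++ib) {
        for (int g = 0; g < 4; ++g) {
            decode_group8(codes[4*ib + g], vals);
            for (int i = 0; i < 8; ++i)
                sum += block_scales[ib]*vals[i]*activations[32*ib + 8*g + i];
        }
    }
    return sum;
}
```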