Support Cohere Command-A (Cohere2ForCausalLM arch) #2912
Comments
+1, would also like support for this!
I've just finished implementing this, and it appears to be working! I've also implemented support for FP8 Rowwise quants for the Cohere models. This was only missing a small piece: FP8 support was implemented for the RmsNorm CUDA kernels but not the LayerNorm ones. Will open a PR once I've run a few more tests.
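For anyone unfamiliar with the term, here is a toy NumPy illustration of what FP8 rowwise quantization of a LayerNorm output involves (one scale per row of activations). It is only a conceptual sketch of the math, not the CUDA kernel discussed above, and the constants are assumptions.

```python
import numpy as np

FP8_E4M3_MAX = 448.0  # largest representable magnitude in float8 e4m3


def layernorm_fp8_rowwise(x, eps=1e-5):
    """Normalize each row, then compute a per-row scale so values fit the FP8 range."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    y = (x - mean) / np.sqrt(var + eps)
    scale = np.abs(y).max(axis=-1, keepdims=True) / FP8_E4M3_MAX  # one scale per row
    y_scaled = np.clip(y / scale, -FP8_E4M3_MAX, FP8_E4M3_MAX)    # what a kernel would cast to e4m3
    return y_scaled, scale


rows, scales = layernorm_fp8_rowwise(np.random.randn(4, 8).astype(np.float32))
print(rows.shape, scales.ravel())  # (4, 8) and one scale per row
```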
Hmm, there are problems with kv cache reuse. If I generate something that exceeds the sliding window size, and then generate something else that shares a prefix, it generates total garbage. Perhaps the implementation of different sliding windows per layer is not compatible with paged cache and cache reuse. Luckily, nvidia recently open-sourced executor and batch_manager, which means we can just... fix it!
I've put the WIP branch here; it's not completely finished (i.e. it does not yet load the sliding window layers from the config), but it can compile and run the model successfully: https://github.com/aikitoria/TensorRT-LLM/tree/experiments Still haven't figured out how to fix the KV cache reuse with the sliding window, though, and the changes to the LayernormQuantization plugin might have broken other models. Right now you have to add
Ok, after looking into it more, I'm pretty sure it's IMPOSSIBLE for us to fix this. To allow block reuse to work, we need to not use a cyclic kv cache; instead we need to keep the kv cache for the entire sequence and only limit the window considered during attention. However, trtllm currently always uses a cyclic kv cache for each layer where a sliding window is enabled, even though this is not actually useful at all for this model, as it also has layers without sliding windows. The early kv blocks being overwritten for those layers then corrupts the blocks when reusing them across different requests*. Fixing this would be possible if we could edit the kernels. However, we are missing the code for the context phase kernel; there is only cubin stuff here: https://github.com/NVIDIA/TensorRT-LLM/tree/main/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention So there's nothing we can do. @byshiue can you guys please either release the code so we can fix it, or officially support Cohere2ForCausalLM including block reuse (and the full sequence length, not just one sliding window like was done for Gemma 2)? The model runs fine in vLLM, but trtllm runs this model more than two times faster, so it's really frustrating that we can't make it work fully. *I might still be misunderstanding the exact issue, but it's difficult to investigate further when there is no code.
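To make the distinction concrete, here is a toy Python sketch (names invented for illustration, nothing here is TRT-LLM code) of why a cyclic kv cache breaks prefix reuse while keeping the full cache and only limiting the attention window does not:

```python
WINDOW = 4  # toy sliding window size


def cyclic_cache(tokens):
    """Ring buffer of size WINDOW: early entries get overwritten."""
    cache = [None] * WINDOW
    for t, tok in enumerate(tokens):
        cache[t % WINDOW] = tok  # position t overwrites position t - WINDOW
    return cache


def windowed_attention_keys(tokens, query_pos):
    """Keep every entry; only restrict which ones attention may read."""
    cache = list(tokens)  # nothing is ever overwritten
    window_start = max(0, query_pos - WINDOW + 1)
    return cache[window_start:query_pos + 1]


prompt = list(range(10))  # pretend token ids 0..9
print(cyclic_cache(prompt))                # [8, 9, 6, 7] -> the prefix 0..3 is gone,
                                           #   so its blocks can't be reused by a later request
print(windowed_attention_keys(prompt, 9))  # [6, 7, 8, 9] -> same keys visible to the last query,
                                           #   but the blocks for 0..3 still exist for reuse
```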
Thanks for raising the issue. We are aware of the garbage result issue when kv cache reuse and sliding window attention are both enabled. We are on it right now.
@ming-wei That would be great, thank you! From exploring the batch_manager code, it looks like someone had the idea of solving it by disabling block reuse if a sliding window is used, but it isn't working because it compares against the max num tokens rather than the smallest window size. Regardless, please don't let that become the solution! This model could fully support block reuse if only there was a way to enable sliding windows without also using a cyclic kv cache. Adding this to the generation stage kernel was pretty easy in the end (keep the loop as it was, but add a window start offset to the timeslots when accessing kv), but I can't modify the context one because it's missing the code.
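For clarity, here is a rough Python rendering of that generation-phase change as described (a toy single-head attention loop, not the actual MMHA kernel): the loop body stays the same, only the start offset over the non-cyclic cache changes.

```python
import math


def generation_attention(q, k_cache, v_cache, seq_len, window_size):
    """Attend over at most `window_size` past timesteps without a cyclic cache.

    k_cache/v_cache hold the full sequence; only the loop bounds change.
    """
    window_start = max(0, seq_len - window_size)  # the added start offset
    scores = []
    for t in range(window_start, seq_len):        # loop body unchanged, bounds shifted
        scores.append(sum(qi * ki for qi, ki in zip(q, k_cache[t])))
    peak = max(scores)
    weights = [math.exp(s - peak) for s in scores]
    total = sum(weights)
    out = [0.0] * len(v_cache[0])
    for w, t in zip(weights, range(window_start, seq_len)):
        for d in range(len(out)):
            out[d] += (w / total) * v_cache[t][d]
    return out


# Example: a 3-dim head with 6 cached timesteps and a window of 4.
k = [[float(t)] * 3 for t in range(6)]
v = [[float(t)] * 3 for t in range(6)]
print(generation_attention([0.1, 0.1, 0.1], k, v, seq_len=6, window_size=4))
```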
It's exciting to see your work on Command-A. But I found a similar problem with sliding window attention when adding support for the gemma3 text LLM. I found that the attention mask of sliding window attention is not a pure causal type, which results in incorrect outputs for long inputs (longer than the sliding window) if I use
It is difficult to tell exactly what goes wrong in the context kernel as we don't have the code; it seemed to me like it is actually working properly when passed
Thanks for answering. But have you tested with a long input that exceeds 4096 (the window)? It's easy to find the issue because the gemma3 window is only 512. I tested gemma 2 with a long input > 4096 (the window), and the phenomenon seems to be the same.
I tried some longer inputs and they were working, but I haven't done super intensive testing yet.
How do you set the attention mask to sliding window? I found that the attention_mask_type you set is causal, not sliding window causal, as follows: Thanks~
Good point, that should probably be fixed in my code, but it automatically changes the mask type to sliding window here, so it should be fine: https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.cpp#L300 And the generation phase kernel does not have this mask type argument (it is purely controlled by the sliding window size).
Thanks! I will continue analyzing.
In addition, I did not find where to set the sliding window size when customizing our own model via the Python API. Could you tell me how to do that properly? Thanks~
I did not find a way to define that as part of the model specification; instead I am currently passing it through from
However, this causes cyclic kv cache
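For anyone else wiring this up, a hedged sketch of passing the per-layer windows at runtime is below. It assumes the executor KvCacheConfig accepts a per-layer max_attention_window list (repeated cyclically across layers); the exact field name may differ between TRT-LLM versions, so check your build.

```python
from tensorrt_llm.bindings import executor as trtllm_executor

MAX_SEQ_LEN = 131072      # placeholder for the model's full context length
SLIDING_WINDOW = 4096

# Command-A style pattern: three sliding-window layers, then one global layer.
per_layer_windows = [SLIDING_WINDOW, SLIDING_WINDOW, SLIDING_WINDOW, MAX_SEQ_LEN]

kv_cache_config = trtllm_executor.KvCacheConfig(
    enable_block_reuse=True,                 # the feature that currently conflicts with SWA
    max_attention_window=per_layer_windows,  # assumed kwarg; verify against your version
)
```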
Thanks very much! It seems I fixed the incorrect output for long inputs by
I will pay attention to the block reuse issue you noticed.
To reproduce the issue, generate a long prompt that exceeds the attention window size, then generate another prompt that shares a few hundred tokens from the start of the previous one (such as a common system message). The output will be garbage. You can work around it by disabling block reuse, of course, but I wouldn't consider that usable in practice...
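Concretely, the repro looks roughly like the sketch below. generate is a placeholder for whichever entry point you use (ModelRunner, the executor API, an OpenAI-compatible server, ...); only the shape of the two requests matters.

```python
system_message = "You are a helpful assistant. " * 20   # shared prefix of a few hundred tokens
long_filler = "Lorem ipsum dolor sit amet. " * 2000      # pushes well past a 4096-token window


def generate(prompt: str) -> str:
    raise NotImplementedError("plug in your TRT-LLM client here")


# Request 1 exceeds the sliding window, so the cyclic cache overwrites early blocks
# in the SWA layers.
out1 = generate(system_message + long_filler + "\nSummarize the text above.")

# Request 2 shares only the system message prefix. With block reuse enabled, the reused
# prefix blocks were partially overwritten by request 1, and the output is garbage.
out2 = generate(system_message + "\nWhat is 2 + 2?")
```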
Alright, I have not enabled the kv cache reuse feature for now.
Let me try to clarify a bit. We are working on a (somewhat complicated) solution to support the alternating sliding window attention + kv cache reuse scenario. By "alternating sliding window attention" I'm referring to models with heterogeneous sliding_window_size across layers, e.g. Cohere Command A or Gemma 2/3. It would involve multiple components in TensorRT-LLM:
I hope this helps clarify things. Let me know if it is still unclear; we are always happy to discuss more :) Thanks,
Awesome, that sounds really good!
Am I understanding it correctly that, in the case where we disable the cyclic kv cache so that block reuse works, this isn't actually necessary, since we have to keep the entire sequence for every layer regardless? That's why I did not think this would be part of the solution, but I can see how it would be useful if block reuse is disabled. Or is it mainly for making the offload logic work better? Since theoretically block reuse would need to fetch the whole sequence for the global layers, but the window layers would only need to fetch the last blocks covering the window size?
You are right about that. If we don't care about device memory savings and offloading blocks to host, "BlockManager per window size" is not needed at all; we could simply keep the entire sequence in device memory, for all layers, regardless. However, "BlockManager per window size" is a prerequisite for "offloading out-of-window blocks to host memory". Without it, a block is considered out of window iff it falls outside max_sliding_window_size (the max of sliding_window_size over all layers). In the very unfortunate case where a full attention layer is present in the model, this effectively rules out any possibility of a block ever going out of window, so no block would be offloaded to host at all. The root cause is that the current BlockManager doesn't have control at layer granularity for a given index position. With "BlockManager per window size", KV entries in sliding window layers and those from full attention layers are decoupled, which helps offloading.
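As a toy illustration of that bookkeeping (none of these names exist in TRT-LLM, and the block/window sizes are assumptions), grouping layers by window size is what allows the out-of-window check to succeed for the sliding-window group even while a full-attention group is present:

```python
from collections import defaultdict

TOKENS_PER_BLOCK = 64


def group_layers_by_window(per_layer_windows):
    """One (hypothetical) block manager per distinct window size."""
    groups = defaultdict(list)
    for layer, window in enumerate(per_layer_windows):
        groups[window].append(layer)
    return groups


def offloadable_blocks(window_size, current_pos, tokens_per_block=TOKENS_PER_BLOCK):
    """Blocks whose tokens are all older than current_pos - window_size."""
    window_start = max(0, current_pos - window_size)
    return list(range(window_start // tokens_per_block))


per_layer = [4096, 4096, 4096, 131072] * 16   # assumed Command-A-like pattern
pos = 20000
for window, layers in group_layers_by_window(per_layer).items():
    print(window, len(layers), "layers,", len(offloadable_blocks(window, pos)), "offloadable blocks")
# A single manager keyed on the max window (131072) would report 0 offloadable blocks for
# every layer at pos=20000; per-window managers can already offload ~248 blocks for the
# 4096 group.
```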
I see, that makes sense. I have another suggestion on the topic: can there be a way to force specific sliding windows through the model specification? For example, for Command-A, we need 3/4 of the layers to be exactly 4096, so it is pretty clunky/error-prone if the user has to manually specify this through the kv cache config when running the model, rather than it... just doing that when this model is used.
Thanks for raising the "sliding window in kv cache config" concern. We'll think about it.
Btw, would it be possible to open source the kernels for FMHA, XQA, and "trtllmGen"? Since MMHA is already open source, that would expose the whole library for people to modify according to the needs of their desired model. It's a bit confusing that we only have cubins for some critical pieces despite nvidia saying the library is open source. It's not like keeping these kernels closed represents any meaningful moat for nvidia, right? At the end of the day they're implemented in CUDA, and people will use nvidia GPUs to run them regardless of whether the code is available.
We don't have plans to open source these kernels for now. We will keep an eye on it and consider the possibility of open sourcing the kernels once we find it appropriate.
I'm one of the devs working on improving (Variable) Sliding Window Attention [(V)SWA] (with and without reuse) support in TRTLLM.
Hi @aikitoria, I know this is unrelated -- do you have the code anywhere to get it running on vLLM? Many thanks!
@b8zhong I don't believe any special code is required. Since the model architecture is supported by vLLM, you can simply launch the model with vLLM and it should work. Though I have uploaded my personal vLLM quant here: https://huggingface.co/aikitoria/c4ai-command-a-03-2025-FP8-Dynamic
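For reference, launching that quant with vLLM's offline API looks roughly like this (tensor_parallel_size and max_model_len are placeholders for your own setup; `vllm serve` works equivalently):

```python
from vllm import LLM, SamplingParams

llm = LLM(
    model="aikitoria/c4ai-command-a-03-2025-FP8-Dynamic",
    tensor_parallel_size=4,   # adjust to your GPU count
    max_model_len=32768,      # trim if you run out of KV cache memory
)
params = SamplingParams(temperature=0.7, max_tokens=256)
outputs = llm.generate(["Write a haiku about sliding window attention."], params)
print(outputs[0].outputs[0].text)
```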
Very nice, I did not spot that feature before, probably because nothing else needed it, like you said. Will update the branch to use it and see if it works.
Did you guys just transition to a new github-based workflow? I am noticing a lot more PRs and small commits being pushed rather than the monolithic ones we saw before.
@aikitoria Yes, we are transitioning our daily work to github. There will be no "monolithic" PRs anymore.
@aikitoria - an update on what I called point #1: taking the
@netanel-haber it looks like specifying RuntimeDefaults in the config causes it to fail to save the JSON after conversion:
I bodged the PretrainedConfig to keep it stored as a dict rather than the native type, which allowed it to convert successfully; however, now there is an error when running the engine:
From what I can tell, the issue here is that
only exposes properties from the underlying native binding class, not functions.
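For anyone hitting the same thing, the "bodge" amounts to roughly the following. The class and field names below are stand-ins based on this thread, not the actual TRT-LLM types: keep runtime_defaults as a plain dict so json.dumps can serialize it, and only construct the native binding object where the runtime actually needs it.

```python
import json


class PatchedPretrainedConfig:
    """Illustrative stand-in: store runtime defaults as a dict, not a native binding."""

    def __init__(self, runtime_defaults: dict | None = None, **kwargs):
        self.runtime_defaults = dict(runtime_defaults or {})  # plain dict -> JSON-serializable
        self.extra = kwargs

    def to_dict(self) -> dict:
        return {"runtime_defaults": self.runtime_defaults, **self.extra}


cfg = PatchedPretrainedConfig(
    runtime_defaults={"max_attention_window": [4096, 4096, 4096, 131072]},
    architecture="Cohere2ForCausalLM",
)
print(json.dumps(cfg.to_dict()))  # serializes cleanly, unlike the native object
```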
Ok, I fixed the RuntimeDefaults not converting correctly on my branch: https://github.com/aikitoria/TensorRT-LLM/tree/experiments/ I can confirm the kv cache is no longer corrupting with the new changes! I still need to load the sliding window sizes from the original hf config; then I think this should be good to go for a PR, and hopefully the separate caches per window will make it better soon! Unless there are other changes you would like to see?
@aikitoria Hi, thanks for working to add new model support to TensorRT-LLM. Please go ahead and create a PR. Now that TensorRT-LLM has become GitHub-first, the interactions with your PR should be even smoother. Let's iterate so that your contributions receive concrete feedback earlier and land more efficiently. Thanks
Hmm, I have been experimenting more with this branch today and noticed I am still getting kv cache corruption. However, it's not entirely clear why; after restarting the server it now seems fine again. Will report if I find a reliable reproduction case.
It would be great to support this new model! https://cohere.com/blog/command-a
They use a fairly unique architecture, where some layers use sliding window attention while others use global attention with no position embeddings, so even though I read through the documentation on how to add a model, I'm a little lost on how to do this myself.
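For a concrete picture of that pattern, here is a small sketch of the per-layer plan (the 3:1 sliding-to-global ratio and the 4096 window are mentioned elsewhere in this thread; the layer count is a placeholder, not checked against the model config):

```python
NUM_LAYERS = 64           # placeholder
SLIDING_WINDOW = 4096


def layer_plan(num_layers=NUM_LAYERS):
    plan = []
    for i in range(num_layers):
        if (i + 1) % 4 == 0:
            # every 4th layer: global attention, no rotary position embedding
            plan.append({"attention": "global", "window": None, "rope": False})
        else:
            plan.append({"attention": "sliding", "window": SLIDING_WINDOW, "rope": True})
    return plan


print(layer_plan()[:5])
```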