
Support Cohere Command-A (Cohere2ForCausalLM arch) #2912


Open · aikitoria opened this issue Mar 14, 2025 · 41 comments
Labels: Investigating, KV-Cache Management, triaged (Issue has been triaged by maintainers)

@aikitoria commented Mar 14, 2025

It would be great to support this new model! https://cohere.com/blog/command-a

They use a fairly unusual architecture: some layers use sliding window attention while others use global attention with no position embeddings. Even though I read through the documentation on how to add a model, I'm a little lost on how to do this myself.

@user-0a commented Mar 14, 2025

+1, would also like support for this!

@aikitoria (Author) commented Mar 17, 2025

I've just finished implementing this, and it appears to be working! I've also implemented support for FP8 Rowwise quants for the Cohere models. This was only missing a small piece: FP8 support was implemented for the RmsNorm CUDA kernels but not the LayerNorm ones.

Will open a PR once I've run a few more tests.

@aikitoria (Author) commented Mar 17, 2025

Hmm, there are problems with kv cache reuse. If I generate something that exceeds the sliding window size, and then generate something else that shares a prefix, it generates total garbage. Perhaps the implementation of different sliding windows per layer is not compatible with paged cache and cache reuse.

Luckily, nvidia recently open-sourced executor and batch_manager, which means we can just... fix it!

@aikitoria (Author) commented Mar 17, 2025

I've put the WIP branch here; it is not completely finished (e.g. it does not yet load the sliding window layers from the config), but it compiles and runs the model successfully: https://github.com/aikitoria/TensorRT-LLM/tree/experiments

Still haven't figured out how to fix KV cache reuse with the sliding window, though. And the changes to the LayernormQuantization plugin might have broken other models.

Right now you have to add enable_block_reuse=False, max_attention_window=[4096, 4096, 4096, 131072] to KvCacheConfig for the model not to produce garbage, which is terrible. I will try to fix it, but this part of the code is very complex and the iteration time is awful due to the huge templates.
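
For reference, a minimal sketch of that workaround through the Python LLM API (illustrative only; exact import paths may differ between versions):

from tensorrt_llm.llmapi import LLM, KvCacheConfig

# Disabling block reuse avoids the corrupted outputs described above,
# at the cost of losing prefix caching entirely.
kv_cache_config = KvCacheConfig(
    enable_block_reuse=False,
    # pattern of 3 sliding-window layers (4096) followed by 1 global layer (131072)
    max_attention_window=[4096, 4096, 4096, 131072],
)

llm = LLM(model="/path/to/command-a-engine", kv_cache_config=kv_cache_config)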

@aikitoria (Author) commented Mar 19, 2025

Ok, after looking into it more, I'm pretty sure it's IMPOSSIBLE for us to fix this. To allow block reuse to work, we need to not use a cyclic kv cache; instead we need to keep the kv cache for the entire sequence and only limit the window considered during attention. However, trtllm currently always uses a cyclic kv cache for each layer where a sliding window is enabled, even though this is not actually useful for this model at all, as it also has layers without sliding windows. The early kv blocks being overwritten for those layers then corrupts the blocks when they are reused across different requests*.
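
To illustrate the difference, a conceptual sketch in Python (not actual TRT-LLM code):

# With a cyclic cache, token t of a sliding-window layer lands in slot t % window,
# so early slots get overwritten once the sequence exceeds the window -- exactly the
# blocks another request with a shared prefix would want to reuse.
def cyclic_slot(t: int, window: int) -> int:
    return t % window  # slot is recycled; the old KV stored there is lost

# The alternative: keep KV for the whole sequence and only restrict which
# positions the attention kernel reads.
def attended_positions(t: int, window: int) -> range:
    return range(max(0, t - window + 1), t + 1)  # window applied at read time

assert cyclic_slot(5000, 4096) == 904          # token 5000 overwrites token 904's slot
assert list(attended_positions(5, 4))[0] == 2  # token 5 attends to positions 2..5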

Fixing this would be possible if we could edit the kernels. However, we are missing the code for the context-phase kernel; there are only cubins here: https://github.com/NVIDIA/TensorRT-LLM/tree/main/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention

So there's nothing we can do.

@byshiue can you guys please either release the code so we can fix it, or officially support Cohere2ForCausalLM including block reuse (and the full sequence length, not just one sliding window as was done for Gemma 2)? The model runs fine in vLLM, but trtllm runs this model more than two times faster, so it is really frustrating that we can't make it work fully.

*I might still be misunderstanding the exact issue, but it's difficult to investigate further when there is no code.

@ming-wei (Collaborator)

Thanks for raising the issue.

We are aware of the garbage result issue when kv cache reuse and sliding window attention are both enabled. We are on it right now.

ming-wei self-assigned this Mar 19, 2025

@aikitoria (Author) commented Mar 19, 2025

@ming-wei That would be great, thank you!

From exploring the batch_manager code, it looks like someone had the idea of solving it by disabling block reuse if a sliding window is used, but it isn't working because it compares against the max num tokens rather than the smallest window size. Regardless, please don't let that become the solution!

It could fully support block reuse on this model if only there were a way to enable sliding windows without also using a cyclic kv cache. Adding this to the generation-stage kernel was pretty easy in the end (keep the loop as it was, but add the window start offset to the timesteps when accessing KV), but I can't modify the context one because its code is missing.

@zhaocc1106

Exciting work on Command-A! I ran into a similar problem with sliding window attention while adding support for the Gemma 3 text LLM. The attention mask for sliding window attention is not a pure causal mask, which leads to wrong results for long inputs (longer than the sliding window) if I use AttentionMaskType.causal, especially for Gemma 3, where the sliding window is only 512.
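
To illustrate, a small sketch of the two mask types (illustrative Python with numpy, not the Gemma 3 or TRT-LLM implementation):

import numpy as np

def causal_mask(seq_len: int) -> np.ndarray:
    # position i may attend to every position j <= i
    return np.tril(np.ones((seq_len, seq_len), dtype=bool))

def sliding_window_causal_mask(seq_len: int, window: int) -> np.ndarray:
    # position i may only attend to the last `window` positions, j in (i - window, i]
    i = np.arange(seq_len)[:, None]
    j = np.arange(seq_len)[None, :]
    return (j <= i) & (j > i - window)

# For inputs longer than the window, the two masks diverge, which is what
# produces wrong results when the plain causal mask is used.
assert not np.array_equal(causal_mask(8), sliding_window_causal_mask(8, 4))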

@aikitoria (Author) commented Mar 19, 2025

It is difficult to tell exactly what goes wrong in the context kernel as we don't have the code. It seemed to me that it actually works properly when attention_mask_type = ContextAttentionMaskType::SLIDING_WINDOW_CAUSAL is passed to the fmha params in combination with a cyclic kv cache. But I don't want a cyclic kv cache, so that block reuse works.

@zhaocc1106

Thanks for answering. But have you tested with a long input that exceeds 4096 (the window)? The issue is easy to see with Gemma 3 because its window is only 512. I tested Gemma 2 with a long input exceeding 4096 (the window), and the behavior seems to be the same.

@aikitoria (Author)

I tried some longer inputs and they were working, but I haven't done super intensive testing yet.

@zhaocc1106

How do you set the attention mask to sliding window? The attention_mask_type you set is causal, not sliding_window_causal, as seen here:
https://github.com/aikitoria/TensorRT-LLM/blob/cc798be8f9f672ab7751cbc38deb94b5bdad8e64/tensorrt_llm/models/commanda/model.py#L43
It also seems AttentionMaskType.sliding_window_causal is not supported properly, per this comment: #2880 (comment)

Thanks~

@aikitoria (Author) commented Mar 19, 2025

Good point, that should probably be fixed in my code, but it automatically changes the mask type to sliding window here, so it should be fine: https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.cpp#L300

And the generation-phase kernel does not have this mask type argument (it is purely controlled by the sliding window size).

@zhaocc1106

Thanks! I will continue analyzing.

@zhaocc1106 commented Mar 19, 2025

In addition, I did not find where to set the sliding window size when customizing our own model via the Python API. Could you tell me how to do that properly? Thanks~

@aikitoria (Author) commented Mar 19, 2025

I did not find a way to define that as part of the model specification; instead I am currently passing it through from serve.py like so:

    kv_cache_config = KvCacheConfig(
        free_gpu_memory_fraction=kv_cache_free_gpu_memory_fraction,
        enable_block_reuse=True,
        max_attention_window=[4096, 4096, 4096, 131072])

However, this causes a cyclic kv cache.
For my later experiments modifying the kernels I did not set this; instead I hacked up attentionOp.cpp and decoderMaskedMultiheadAttentionTemplate.h to override it for specific layers, which appeared to work for generating sequences beyond the window size, but I can't fix the context phase beyond the window size.
I didn't push this code because it isn't working.

@zhaocc1106

Thanks very much! It seems I fixed the wrong output for long inputs with

kv_cache_config = KvCacheConfig(
        free_gpu_memory_fraction=kv_cache_free_gpu_memory_fraction,
        enable_block_reuse=True,
        max_attention_window=[4096, 4096, 4096, 131072])

I will pay attention to the block reuse issue you noticed.

@aikitoria (Author) commented Mar 20, 2025

To reproduce the issue, generate a long prompt that exceeds the attention window size, then generate another prompt that shares a few hundred tokens from the start of the previous one (such as a common system message).

The output will be garbage. You can work around it by disabling block reuse, of course, but I wouldn't consider that usable in practice...
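
A rough repro sketch through the Python LLM API (paths and prompt contents are placeholders; exact import paths may differ between versions):

from tensorrt_llm.llmapi import LLM, KvCacheConfig, SamplingParams

llm = LLM(
    model="/path/to/command-a-engine",
    kv_cache_config=KvCacheConfig(
        enable_block_reuse=True,
        max_attention_window=[4096, 4096, 4096, 131072],
    ),
)

system_message = "..."                  # shared prefix of a few hundred tokens
first_prompt = system_message + "..."   # total length exceeds the 4096-token sliding window
second_prompt = system_message + "..."  # different continuation, same shared prefix

params = SamplingParams(max_tokens=128)
print(llm.generate([first_prompt], params)[0].outputs[0].text)   # looks fine
print(llm.generate([second_prompt], params)[0].outputs[0].text)  # garbage once blocks are reused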

@zhaocc1106

Alright, I am not enabling the kv cache reuse feature for now.

@zhaocc1106

(screenshot) I just reproduced the garbage output by sending the same long input twice with kv cache reuse enabled.

@ming-wei (Collaborator)

Let me try to clarify a bit. We are working on a (somewhat complicated) solution to support the alternating sliding window attention + kv cache reuse scenario. By "alternating sliding window attention" I'm referring to models with heterogeneous sliding_window_size across layers, e.g. cohere command A or gemma 2/3.

It would involve multiple components in TensorRT-LLM:

  • BlockManager per window size
    Nowadays, we assume all KV blocks across layers have the same sliding_window_size in BlockManager. This assumption doesn't hold in real life, however. As a workaround, we take sliding_window_size to be the max sliding_window_size among all layers, but a lot of memory is wasted, since we have to keep a cyclic kv cache buffer of max_sliding_window_size for all layers rather than sizing it on a per-layer basis.

    To resolve this issue, we are working on a per-window-size BlockManager solution, which assigns layers of different sliding_window_size to different pools for finer-grained control (a rough sketch of this grouping follows the list). With this change, people can expect lower memory usage on models like cohere command A/gemma 3.

  • Disable cyclic KV cache
    As the OP pointed out earlier in this issue (#2912), the cyclic kv cache prevents KV cache reuse. Disabling the cyclic KV cache would entail two parts: the runtime part and the kernel part. The runtime part is about BlockManager: all cyclic logic should be removed from BlockManager. The kernel part is about the GPU kernel carrying out the actual MHA computation, including the context (prefill) phase FMHA and the generation (decode) phase XQA/MMHA.

  • Offload out-of-window blocks to host memory
    An unpleasant outcome of disabling the cyclic KV cache is the excessive amount of memory needed for storing the whole KV cache. Fortunately, BlockManager has the ability to offload blocks to host memory when needed. We just need to enhance it a bit to offload blocks to host memory as soon as they go out of the sliding window.
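
For illustration, here is a loose sketch of the per-window-size grouping from the first point (pure illustration in Python, not the actual BlockManager code):

from collections import defaultdict

# Hypothetical per-layer window sizes for a Command-A-like pattern:
# 3 sliding-window layers (4096) followed by 1 full-attention layer, repeated.
layer_window_sizes = [4096, 4096, 4096, 131072] * 16

# Group layers into pools keyed by window size, so allocation and offloading
# decisions can be made per pool rather than using the max window for every layer.
pools = defaultdict(list)
for layer_idx, window in enumerate(layer_window_sizes):
    pools[window].append(layer_idx)

for window, layers in pools.items():
    print(f"window={window}: {len(layers)} layers")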

I hope this helps clarify things. Let me know if it is still unclear, we are always happy to discuss more :)

Thanks,
Ming Wei on behalf of TensorRT-LLM team

github-actions bot added the triaged (Issue has been triaged by maintainers) and Investigating labels Mar 20, 2025
@aikitoria (Author) commented Mar 20, 2025

Awesome, that sounds really good!

BlockManager per window size

Am I understanding correctly that in the case where we disable the cyclic kv cache so block reuse works, this isn't actually necessary, since we have to keep the entire sequence for every layer regardless? That's why I did not think this would be part of the solution, but I can see how it would be useful if block reuse is disabled.

Or is it mainly for making the offload logic work better? Since theoretically block reuse would need to fetch the whole sequence for the global layers, but the window layers would only need to fetch the last blocks covering the window size?

@ming-wei (Collaborator)

You are right about that. If we don't care about saving device memory and offloading blocks to host, "BlockManager per window size" is not needed at all. We could simply keep the entire sequence in device memory, for all layers, regardless.

However, "BlockManager per window size" is a prerequisite of "offloading out-of-window blocks to host memory". Without "BlockManager per window size", a block is considered out of window only if it falls outside max_sliding_window_size (the max of sliding_window_size over all layers). In the very unfortunate case where a full attention layer is present in the model, this effectively rules out any possibility of a block going out of window, so no block would be offloaded to host at all.

The root cause is that the current BlockManager doesn't have control at layer granularity: for a given index position $i$ in the KV cache, the KV entries at $i$ from different layers actually belong to one single KV block. So the KV entries in sliding window layers are "encumbered" by the KV entries in the full attention layers: since the latter cannot be offloaded, the former cannot either.

With "BlockManager per window size", KV entries in sliding window layers and those from full attention layers are decoupled, which helps offloading.

@aikitoria (Author)

I see, that makes sense. I have another suggestion on the topic: could there be a way to force specific sliding windows through the model specification? For example, for Command-A we need 3/4 of the layers to be exactly 4096, so it is pretty clunky/error-prone if the user has to manually specify this through the kv cache config when running the model rather than it... just doing that when this model is used.

@ming-wei (Collaborator)

Thanks for raising the "sliding window in kv cache config" concern. We'll think about it.

@aikitoria (Author) commented Mar 20, 2025

Btw, would it be possible to open source the kernels for FMHA, XQA, and "trtllmGen" ? Since MMHA is already open source, that would expose the whole library for people to modify according to the needs of their desired model. It's a bit confusing that we only have cubins for some critical pieces despite nvidia saying the library is open source.

It's not like keeping these kernels closed represents any meaningful moat for nvidia right? At the end of the day they're implemented with cuda and people will use nvidia gpus to run it regardless of whether the code is available.

@ming-wei (Collaborator)

We don't have plans to open source these kernels for now.

We will keep an eye on it and consider the possibility of open-sourcing the kernels once we find it appropriate.

@netanel-haber (Collaborator) commented Mar 24, 2025

I'm one of the devs working on improving (Variable) Sliding Window Attention [(V)SWA] support in TRT-LLM (with and without reuse).
Addressing some of your points:

  1. because it compares to the max num tokens rather than the smallest window size

    This PR is meant to partially solve this issue; I'll update here when it's merged. As explained above, the intention is to remove the cyclic kv cache promptly [@ming-wei and @tomeras91 are working on just that], but this is a bugfix in the meantime.

  2. I did not find a way to define that as part of the model specification

    This is a minor thing, but some months back we added a small feature that allows specifying a default value of max_attention_window [and/or sink_token_length] in the model config, rather than having to specify it at runtime; see RuntimeDefaults and this test for an example of using it in Python.
    The feature isn't yet documented and isn't enabled by default. It allows passing a RuntimeDefaults to PretrainedConfig, so every config.py can provide one by default if it wants to. So, in your branch, for example:

            return Cohere2Config(
                architecture=hf_config.architectures[0],
                ...,
                runtime_defaults=RuntimeDefaults(max_attention_window=[4096, 4096, 4096, 131072]),
                **kwargs)

    There is no magic here: the config just reads an optional runtime_defaults when the runtime is initialized; you can track it through the code. You would be the first to consume this feature. Please let me know if you encounter any issues! Note that whether or not you use this, the cyclic cache is still the implementation of the kv cache for now.

  3. BlockManager per window size

    This has been a WIP for a long time, and I hope it lands in the near future, but I can't promise when. I'll update here, of course, when it does make it in. It should finally enable real memory savings on Gemma 2 (and especially Gemma 3), etc.

@b8zhong commented Mar 24, 2025

Hi @aikitoria, I know this is unrelated -- do you have the code anywhere to get it running on vLLM?

Many thanks!

@aikitoria (Author)

@b8zhong I don't believe any special code is required. Since the model architecture is supported by vLLM, you can simply launch the model with vLLM and it should work. Though I have uploaded my personal vLLM quant here: https://huggingface.co/aikitoria/c4ai-command-a-03-2025-FP8-Dynamic
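
For example, something along these lines should be enough (an untested sketch; it assumes a vLLM build with Cohere2ForCausalLM support and enough GPU memory, tensor-parallel settings omitted):

from vllm import LLM, SamplingParams

# Load the FP8 quant and run a single short generation as a smoke test.
llm = LLM(model="aikitoria/c4ai-command-a-03-2025-FP8-Dynamic")
outputs = llm.generate(["Hello!"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)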

@aikitoria (Author)

RuntimeDefaults

Very nice. I did not spot that feature before, probably because, as you said, nothing else needed it. I will update the branch to use it and see if it works.

@aikitoria (Author)

Did you guys just transition to a new GitHub-based workflow? I am noticing a lot more PRs and small commits being pushed, rather than the monolithic ones we saw before.

@ming-wei (Collaborator)

@aikitoria Yes, we are transitioning our daily work to GitHub. There will be no "monolithic" PRs anymore.

@netanel-haber (Collaborator)

@aikitoria - an update on what I called point 1 (taking the minAttentionWindow into account): it was merged into main: #2983

@aikitoria (Author)

@netanel-haber it looks like specifying RuntimeDefaults in the config causes it to fail to save the json after conversion:

Traceback (most recent call last):
  File "/code/tensorrt_llm/examples/commanda/convert_checkpoint.py", line 188, in execute
    future.result()
  File "/usr/lib/python3.12/concurrent/futures/_base.py", line 449, in result
    return self.__get_result()
           ^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/concurrent/futures/_base.py", line 401, in __get_result
    raise self._exception
  File "/usr/lib/python3.12/concurrent/futures/thread.py", line 58, in run
    result = self.fn(*self.args, **self.kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/code/tensorrt_llm/examples/commanda/convert_checkpoint.py", line 171, in convert_and_save_rank
    Cohere2.save_checkpoint(args.output_dir, save_config=(rank == 0))
  File "/code/tensorrt_llm/tensorrt_llm/models/modeling_utils.py", line 779, in save_checkpoint
    self.config.to_json_file(os.path.join(output_dir, 'config.json'))
  File "/code/tensorrt_llm/tensorrt_llm/models/modeling_utils.py", line 484, in to_json_file
    json.dump(self.to_dict(), f, indent=4)
              ^^^^^^^^^^^^^^
  File "/code/tensorrt_llm/tensorrt_llm/models/commanda/config.py", line 41, in to_dict
    output = super().to_dict()
             ^^^^^^^^^^^^^^^^^
  File "/code/tensorrt_llm/tensorrt_llm/models/modeling_utils.py", line 453, in to_dict
    output = copy.deepcopy(self.__dict__)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/copy.py", line 136, in deepcopy
    y = copier(x, memo)
        ^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/copy.py", line 221, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
                             ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/copy.py", line 151, in deepcopy
    rv = reductor(4)
         ^^^^^^^^^^^
TypeError: cannot pickle 'tensorrt_llm.bindings.executor.RuntimeDefaults' object
Traceback (most recent call last):
  File "/code/tensorrt_llm/examples/commanda/convert_checkpoint.py", line 215, in <module>
    main()
  File "/code/tensorrt_llm/examples/commanda/convert_checkpoint.py", line 207, in main
    convert_and_save_hf(args)
  File "/code/tensorrt_llm/examples/commanda/convert_checkpoint.py", line 174, in convert_and_save_hf
    execute(args.workers, [convert_and_save_rank] * world_size, args)
  File "/code/tensorrt_llm/examples/commanda/convert_checkpoint.py", line 192, in execute
    assert len(
           ^^^^
AssertionError: Checkpoint conversion failed, please check error log.

@aikitoria (Author)

I bodged the PretrainedConfig to keep it stored as a dict rather than the native type, which allowed it to convert successfully; however, now there is an error when running the engine:

[TRT-LLM] [E] Failed to parse the arguments for the LLM constructor: 'KvCacheConfig' object has no attribute 'fill_empty_fields_from_runtime_defaults'
Traceback (most recent call last):
  File "/usr/local/bin/trtllm-serve", line 8, in <module>
    sys.exit(main())
             ^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/click/core.py", line 1157, in __call__
    return self.main(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/click/core.py", line 1078, in main
    rv = self.invoke(ctx)
         ^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/click/core.py", line 1434, in invoke
    return ctx.invoke(self.callback, **ctx.params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/click/core.py", line 783, in invoke
    return __callback(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/code/tensorrt_llm/tensorrt_llm/commands/serve.py", line 150, in main
    llm = LLM(**llm_args)
          ^^^^^^^^^^^^^^^
  File "/code/tensorrt_llm/tensorrt_llm/llmapi/llm.py", line 128, in __init__
    raise e
  File "/code/tensorrt_llm/tensorrt_llm/llmapi/llm.py", line 112, in __init__
    self.args = LlmArgs.from_kwargs(
                ^^^^^^^^^^^^^^^^^^^^
  File "/code/tensorrt_llm/tensorrt_llm/llmapi/llm_args.py", line 961, in from_kwargs
    ret._setup()
  File "/code/tensorrt_llm/tensorrt_llm/llmapi/llm_args.py", line 1026, in _setup
    self.kv_cache_config.fill_empty_fields_from_runtime_defaults(
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/pydantic/main.py", line 891, in __getattr__
    raise AttributeError(f'{type(self).__name__!r} object has no attribute {item!r}')
AttributeError: 'KvCacheConfig' object has no attribute 'fill_empty_fields_from_runtime_defaults'

@aikitoria (Author)

From what I can tell, the issue here is that

@PybindMirror.mirror_pybind_fields(_KvCacheConfig)
class KvCacheConfig(BaseModel, PybindMirror):

only exposes properties from the underlying native binding class, not functions.

@aikitoria (Author) commented Mar 26, 2025

OK, I fixed the RuntimeDefaults not converting correctly on my branch: https://github.com/aikitoria/TensorRT-LLM/tree/experiments/

I can confirm the kv cache is no longer getting corrupted with the new changes!

I still need to load the sliding window sizes from the original HF config; then I think this should be good to go for a PR, and hopefully the separate caches per window size will make things better soon. Unless there are other changes you would like to see?

@juney-nvidia (Collaborator)

@aikitoria Hi, thanks for working to add new model support to TensorRT-LLM. Please go ahead and create a PR. Now that TensorRT-LLM has become GitHub-first, interaction on your PR should be even smoother. Let's iterate so that your contributions receive concrete feedback earlier and the process is more efficient.

Thanks
June

@aikitoria (Author)

Ok I have split it into two PRs and posted them here:
#3127
#3128

@aikitoria (Author)

Hmm, I have been experimenting more with this branch today and noticed I am still getting kv cache corruption. However, it's not entirely clear why; after restarting the server it now seems fine again. I will report back if I find a reliable reproduction case.
