Skip to content

Refactoring of multi-head attention and support for KV caching #2061

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from

Conversation

mseeger
Copy link
Contributor

@mseeger mseeger commented May 30, 2025

This continues from #1934 . I created a new branch, because the history of the previous one was messed up with a merge operation.

Adds abstraction for key-value caches, implements batched inference.

I am also adding two baseline KV caches, the default one from before (all KV are stored) and a last-recent one.

OK, this PR contains the following parts:

  • Small things: Start of layer hook in GPT.forward, skip_lm_head in GPT.forward. I need these for gradient computation, but also to put proper head models on top of the transformer. This is generally useful.
  • Refactoring of multi-head attention: This is needed in order to implement the KV cache abstraction in the way @t-vi suggested (in a phone call). But it also really simplifies things. It also removes a major issue: mask_cache requires lots of memory, it is now computed on demand, with particular attention to inference (where query is much smaller than key)
  • Proper KV cache abstraction, which modifies slightly how GPT.forward is called (namely, input_pos as int). This simplifies things, though. I also provide a few default implementations. DenseKVCache replicates what is currently in place.

In the library I am writing, there are a number of additional more powerful KV caches, such as H2O and quantization-aware H2O. I am also working on fine-tuning in the presence of KV caches. The abstraction I propose here, enables all of that.

If these changes are not done, I'd have to copy and change quite a bit of your code. This would be hard to maintain, and would run the risk that KV caches are implemented differently at a later point, and then things really diverge.

As I said in the comments above, I found KV caching to be super-important to make large context inference work on a moderate GPU budget, which should be of interest to your customers as well.

@mseeger
Copy link
Contributor Author

mseeger commented May 30, 2025

Started work to make sure all tests pass.

@mseeger
Copy link
Contributor Author

mseeger commented May 30, 2025

@t-vi , @Borda , just a heads-up, I continue work in this PR, from #1934

@mseeger
Copy link
Contributor Author

mseeger commented May 30, 2025

Tests fail for me that should fail in mainline as well. For example, test_against_multimodal_gemma_3 in test_models.py fails in copy_weights_gemma_3, because the skip logic there checks for prefix "vision_tower" or "language_model", but the keys really start with "model.vision_tower" or "model.language_model".

??

@mseeger
Copy link
Contributor Author

mseeger commented May 30, 2025

I'll submit a PR with a fix.

@mseeger mseeger force-pushed the kvcache4 branch 9 times, most recently from 0546608 to 5442ea0 Compare June 6, 2025 08:37
@Borda
Copy link
Member

Borda commented Jun 10, 2025

I'll submit a PR with a fix.

Could you also link the PR here?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants