
Commit cf84811: Add cache quantization

1 parent: fc5b39c

14 files changed (+1156, -16 lines)

README.md

Lines changed: 1 addition & 2 deletions
@@ -6,7 +6,7 @@ This is an **early preview release** of ExLlamaV3. Please note:
 - The framework <u>is not yet fully optimized</u>. Performance is lacking, especially on Ampere, and there may be a significant CPU bottleneck on slower processors until the extension functions are fully built out.
 - AMD GPUs (ROCm) are not yet supported.
 - [FlashAttention-2](https://github.com/Dao-AILab/flash-attention) is currently required. I hope to switch over to [FlashInfer](https://github.com/flashinfer-ai/flashinfer/tree/main) in time, but there are some obstacles to overcome first.
-- A number of important features are yet to be added, such as cache quantization, tensor parallelism and multimodal support.
+- A number of important features are yet to be added, such as tensor parallelism and multimodal support.
 - There are no release builds yet.
 - Integration into [TabbyAPI](https://github.com/theroyallab/tabbyAPI/) is planned when all the core functionality is in place.

@@ -26,7 +26,6 @@ There's much that still needs to be added and/or ported over from ExLlamaV2. I'v
 - Samplers (most notably repetition penalties and min-P are missing)
 - Constrained sampling (JSON filters etc.)
 - Multimodal support
-- Cache quantization
 - LoRA support
 - ROCm support
 - Tensor-parallel inference

doc/cq_humaneval.png

59.5 KB

exllamav3/cache/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -1,2 +1,3 @@
 from .cache import Cache, CacheLayer
-from .fp16 import CacheLayer_fp16
+from .fp16 import CacheLayer_fp16
+from .quant import CacheLayer_quant

exllamav3/cache/cache.py

Lines changed: 29 additions & 4 deletions
@@ -16,6 +16,7 @@ def __init__(
         config: Config,
         attention: Attention,
         max_num_tokens: int,
+        **kwargs
     ):
         self.config = config
         self.attention = attention

@@ -30,7 +31,18 @@ def free(self):
         pass

     @abstractmethod
-    def get_kv(self):
+    def get_kv(self, cache_seqlens: torch.Tensor, block_table: torch.Tensor) -> tuple:
+        pass
+
+    @abstractmethod
+    def update_kv(
+        self,
+        cache_seqlens: torch.Tensor,
+        block_table: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        length: int
+    ):
         pass

     @abstractmethod

@@ -45,6 +57,7 @@ def __init__(
         model: Model,
         max_num_tokens: int,
         layer_type: Type[CacheLayer] | None = None,
+        **kwargs
     ):
         """
         Create cache for model

@@ -71,7 +84,7 @@

         self.num_layers = len(self.model.get_cache_layers())
         self.layers = [
-            self.layer_type(self.config, attn, self.max_num_tokens)
+            self.layer_type(self.config, attn, self.max_num_tokens, **kwargs)
             for attn in self.model.get_cache_layers()
         ]
         self.attach_to_model()

@@ -107,8 +120,20 @@ def detach_from_model(self, model: Model | None = None):
             module.cache_layers.remove(layer)


-    def get_layer(self, idx: int) -> tuple:
-        return self.layers[idx].get_kv()
+    def get_layer(self, idx: int, cache_seqlens: torch.Tensor, block_table: torch.Tensor) -> tuple:
+        return self.layers[idx].get_kv(cache_seqlens, block_table)
+
+
+    def update_layer(
+        self,
+        idx: int,
+        cache_seqlens: torch.Tensor,
+        block_table: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        length: int
+    ):
+        return self.layers[idx].update_kv(cache_seqlens, block_table, k, v, length)


     def copy_page(
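
Note: taken together with CacheLayer_quant (added in quant.py further down), the new **kwargs plumbing means a quantized cache is requested by passing the layer type and bit widths to the Cache constructor; everything beyond max_num_tokens is forwarded to each layer's constructor. A minimal sketch, assuming the loading API shown in the project README (Config.from_directory, Model.from_config) and a hypothetical model path:

    from exllamav3 import Config, Model, Cache
    from exllamav3.cache import CacheLayer_quant

    config = Config.from_directory("/path/to/model")     # hypothetical path
    model = Model.from_config(config)
    model.load("cuda:0")

    # Extra keyword arguments are forwarded to CacheLayer_quant via **kwargs
    cache = Cache(
        model,
        max_num_tokens = 16384,            # must be a multiple of PAGE_SIZE
        layer_type = CacheLayer_quant,
        k_bits = 4,                        # 2..8 bits per key element
        v_bits = 4,                        # 2..8 bits per value element
    )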

exllamav3/cache/fp16.py

Lines changed: 15 additions & 7 deletions
@@ -47,16 +47,24 @@ def free(self):


     @override
-    def get_kv(self):
+    def get_kv(self, cache_seqlens: torch.Tensor, block_table: torch.Tensor) -> tuple:
         return self.k, self.v


+    @override
+    def update_kv(
+        self,
+        cache_seqlens: torch.Tensor,
+        block_table: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        length: int
+    ):
+        pass
+
+
     @override
     def copy_page(self, source: CacheLayer_fp16, from_page: int, to_page: int, num_tokens: int):
         assert self.shape == source.shape
-        kd = self.k[to_page, :num_tokens, :, :]
-        vd = self.v[to_page, :num_tokens, :, :]
-        ks = source.k[from_page, :num_tokens, :, :]
-        vs = source.v[from_page, :num_tokens, :, :]
-        kd.copy_(ks, non_blocking = True)
-        vd.copy_(vs, non_blocking = True)
+        self.k[to_page, :num_tokens, :, :].copy_(source.k[from_page, :num_tokens, :, :], non_blocking = True)
+        self.v[to_page, :num_tokens, :, :].copy_(source.v[from_page, :num_tokens, :, :], non_blocking = True)
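
Note: both layer types now share the same interface. get_kv returns fp16 K/V tensors for the attention kernel and update_kv is called after the forward pass; for the fp16 layer update_kv is a no-op, since new keys and values are written directly into the persistent cache tensors, while the quantized layer (quant.py below) dequantizes into scratch tensors in get_kv and re-quantizes them in update_kv. A hypothetical call pattern with illustrative names (the real call sites are in the attention module, which is not part of this diff):

    # cache_seqlens and block_table come from the paged-attention bookkeeping
    k_cache, v_cache = cache.get_layer(layer_idx, cache_seqlens, block_table)
    # ... run attention; new keys/values for this pass end up in k_cache / v_cache ...
    cache.update_layer(layer_idx, cache_seqlens, block_table, k_cache, v_cache, length = num_new_tokens)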

exllamav3/cache/quant.py

Lines changed: 100 additions & 0 deletions
@@ -0,0 +1,100 @@
+from __future__ import annotations
+from typing_extensions import override
+import torch
+import torch.nn.functional as F
+from torch import nn
+from ..constants import PAGE_SIZE
+from ..models import Model, Config
+from .cache import CacheLayer
+from typing import TYPE_CHECKING
+from exllamav3.ext import exllamav3_ext as ext
+if TYPE_CHECKING:
+    from ..modules import Attention
+
+class CacheLayer_quant(CacheLayer):
+
+    def __init__(
+        self,
+        config: Config,
+        attention: Attention,
+        max_num_tokens: int,
+        k_bits: int,
+        v_bits: int,
+    ):
+        super().__init__(config, attention, max_num_tokens)
+
+        assert max_num_tokens % PAGE_SIZE == 0, \
+            f"max_num_tokens must be a multiple of {PAGE_SIZE}."
+        assert (2 <= k_bits <= 8) and (2 <= v_bits <= 8), "quantized cache must be from 2 to 8 bits"
+
+        self.shape = (
+            (max_num_tokens // PAGE_SIZE, PAGE_SIZE, attention.num_kv_heads, attention.head_dim)
+            if attention else None
+        )
+
+        self.k_bits = k_bits
+        self.v_bits = v_bits
+        self.token_dim = attention.num_kv_heads * attention.head_dim
+        self.qshape_k = ((max_num_tokens // PAGE_SIZE, PAGE_SIZE, self.token_dim // 32 * k_bits) if attention else None)
+        self.qshape_v = ((max_num_tokens // PAGE_SIZE, PAGE_SIZE, self.token_dim // 32 * v_bits) if attention else None)
+        self.qshape_s = ((max_num_tokens // PAGE_SIZE, PAGE_SIZE, self.token_dim // 32) if attention else None)
+
+        self.qk = None
+        self.qv = None
+        self.sk = None
+        self.sv = None
+        self.device = None
+
+
+    @override
+    def alloc(self, device: torch.device):
+        self.device = device
+        self.qk = torch.zeros(self.qshape_k, dtype = torch.int, device = device) if self.shape else None
+        self.qv = torch.zeros(self.qshape_v, dtype = torch.int, device = device) if self.shape else None
+        self.sk = torch.zeros(self.qshape_s, dtype = torch.half, device = device) if self.shape else None
+        self.sv = torch.zeros(self.qshape_s, dtype = torch.half, device = device) if self.shape else None
+
+
+    @override
+    def free(self):
+        self.device = None
+        self.qk = None
+        self.qv = None
+        self.sk = None
+        self.sv = None
+
+
+    @override
+    def get_kv(self, cache_seqlens: torch.Tensor, block_table: torch.Tensor):
+        k = torch.empty(self.shape, dtype = torch.half, device = self.device)
+        v = torch.empty(self.shape, dtype = torch.half, device = self.device)
+        ext.dequant_cache_paged(self.qk, self.sk, k, self.qv, self.sv, v, cache_seqlens, block_table, PAGE_SIZE)
+        return k, v
+
+
+    @override
+    def update_kv(
+        self,
+        cache_seqlens: torch.Tensor,
+        block_table: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        length: int
+    ):
+        ext.quant_cache_paged(
+            k, self.qk, self.sk,
+            v, self.qv, self.sv,
+            cache_seqlens, block_table,
+            PAGE_SIZE,
+            length
+        )
+
+
+    @override
+    def copy_page(self, source: CacheLayer_quant, from_page: int, to_page: int, num_tokens: int):
+        assert self.qshape_k == source.qshape_k
+        assert self.qshape_v == source.qshape_v
+        self.qk[to_page, :num_tokens, :].copy_(source.qk[from_page, :num_tokens, :], non_blocking = True)
+        self.qv[to_page, :num_tokens, :].copy_(source.qv[from_page, :num_tokens, :], non_blocking = True)
+        self.sk[to_page, :num_tokens, :].copy_(source.sk[from_page, :num_tokens, :], non_blocking = True)
+        self.sv[to_page, :num_tokens, :].copy_(source.sv[from_page, :num_tokens, :], non_blocking = True)
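
Note: the packed shapes determine the memory saving. Per token, each of K and V stores token_dim = num_kv_heads * head_dim values packed into token_dim // 32 * bits int32 words, plus one fp16 scale per group of 32 values. A back-of-the-envelope sketch of the per-token footprint (illustrative arithmetic, not measurements from this commit):

    # Per-token cache footprint implied by qshape_k / qshape_v / qshape_s
    def quant_cache_bytes_per_token(token_dim: int, k_bits: int, v_bits: int) -> int:
        qk = token_dim // 32 * k_bits * 4     # packed int32 words for K, 4 bytes each
        qv = token_dim // 32 * v_bits * 4     # packed int32 words for V
        scales = 2 * (token_dim // 32) * 2    # one fp16 scale per 32 values, for K and for V
        return qk + qv + scales

    def fp16_cache_bytes_per_token(token_dim: int) -> int:
        return 2 * token_dim * 2              # K and V, 2 bytes per element

    # Example: 8 KV heads x 128 head_dim = 1024
    print(quant_cache_bytes_per_token(1024, 4, 4))   # 1152
    print(fp16_cache_bytes_per_token(1024))          # 4096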

exllamav3/exllamav3_ext/bindings.cpp

Lines changed: 7 additions & 0 deletions
@@ -21,6 +21,8 @@
 #include "generator/sampling_basic.cuh"
 #include "generator/gumbel.cuh"

+#include "cache/q_cache.cuh"
+
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
 {
     m.def("stloader_read", &stloader_read, "stloader_read");

@@ -56,4 +58,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)

     m.def("partial_strings_match", &partial_strings_match, "partial_strings_match");
     m.def("count_match_tensor", &count_match_tensor, "count_match_tensor");
+
+    m.def("quant_cache_cont", &quant_cache_cont, "quant_cache_cont");
+    m.def("dequant_cache_cont", &dequant_cache_cont, "dequant_cache_cont");
+    m.def("quant_cache_paged", &quant_cache_paged, "quant_cache_paged");
+    m.def("dequant_cache_paged", &dequant_cache_paged, "dequant_cache_paged");
 }
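
Note: the rounding and bit-packing themselves live in the CUDA kernels declared in cache/q_cache.cuh, which this page does not show. Purely as a mental model (an assumption about the scheme, not the kernel's actual implementation), a group-wise quantizer with one fp16 scale per 32 values, matching the sk/sv shapes above, might look like this in PyTorch:

    import torch

    def quantize_groups(x: torch.Tensor, bits: int):
        # Hypothetical reference: symmetric round-to-nearest per group of 32 values.
        # The real quant_cache_* kernels also pack the codes into int32 words.
        g = x.reshape(-1, 32).float()
        qmax = 2 ** (bits - 1) - 1
        scale = (g.abs().amax(dim = -1, keepdim = True) / qmax).clamp(min = 1e-8)
        q = torch.clamp(torch.round(g / scale), -qmax - 1, qmax)
        return q.to(torch.int8), scale.half()

    def dequantize_groups(q: torch.Tensor, scale: torch.Tensor, shape):
        # Broadcast the per-group scale back over each group of 32 codes
        return (q.float() * scale.float()).reshape(shape).half()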
