
Commit ffd87f2

saood06 and Kawrakow authored
Make prompt cache saving and restoring MLA aware (#497)
* Remove kv_l, kvt_l and just use k_l and v_l
* Hopefully take care of missing V cache (MLA)
* Fix save and restore when there is no V cache
* Fix double print
* Update write_kv_cache_data and read_kv_cache_data to be MLA aware

---------

Co-authored-by: Iwan Kawrakow <[email protected]>
1 parent eded4e2 · commit ffd87f2
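The gist of the change: when MLA attention is active, each per-layer K-cache row stores the compressed KV latent (kv_lora_rank elements) plus the RoPE'd key part (n_embd_head_qk_rope elements) rather than a full GQA key row, so the row size written to and checked against the session file must be computed accordingly. A minimal sketch of that selection, using the names from the diff; the helper itself is hypothetical and not part of the commit:

    // Hypothetical helper (illustration only, not in the commit):
    // bytes per K-cache row for layer il, MLA-aware.
    static uint64_t k_cache_row_bytes(const llama_context * ctx, const llama_hparams & hparams,
                                      enum ggml_type k_type, uint32_t il) {
        if (ctx->cparams.mla_attn == 0) {
            // standard cache: one full (GQA) key row per cell
            const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
            return ggml_row_size(k_type, n_embd_k_gqa);
        }
        // MLA cache: compressed KV latent + RoPE'd key head per cell
        return ggml_row_size(k_type, hparams.n_lora_kv + hparams.n_rot);
    }

Both the writer and the reader compute the row size the same way, so a session saved under one mla_attn setting is rejected on load under the other instead of being restored with the wrong stride.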

1 file changed: +7 -2 lines

src/llama.cpp: 7 additions & 2 deletions
@@ -21448,13 +21448,15 @@ struct llama_data_write {
         // Get whole range at a time
         for (uint32_t il = 0; il < n_layer; ++il) {
             const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
+            const uint32_t n_embd_head_qk_rope = hparams.n_rot;
+            const uint32_t kv_lora_rank = hparams.n_lora_kv;
 
             // Write key type
             const int32_t k_type_i = (int32_t)kv_self.k_l[il]->type;
             write(&k_type_i, sizeof(k_type_i));
 
             // Write row size of key
-            const uint64_t k_size_row = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa);
+            const uint64_t k_size_row = (ctx->cparams.mla_attn == 0) ? ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa) : ggml_row_size(kv_self.k_l[il]->type, kv_lora_rank + n_embd_head_qk_rope);
             write(&k_size_row, sizeof(k_size_row));
 
             // Read each range of cells of k_size length each into tmp_buf and write out
@@ -21758,6 +21760,9 @@ struct llama_data_read {
         // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
         for (uint32_t il = 0; il < n_layer; ++il) {
             const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
+            const uint32_t n_embd_head_qk_rope = hparams.n_rot;
+            const uint32_t kv_lora_rank = hparams.n_lora_kv;
+
 
             // Read type of key
             int32_t k_type_i_ref;
@@ -21771,7 +21776,7 @@ struct llama_data_read {
             // Read row size of key
             uint64_t k_size_row_ref;
             read_to(&k_size_row_ref, sizeof(k_size_row_ref));
-            const size_t k_size_row = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa);
+            const uint64_t k_size_row = (ctx->cparams.mla_attn == 0) ? ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa) : ggml_row_size(kv_self.k_l[il]->type, kv_lora_rank + n_embd_head_qk_rope);
             if (k_size_row != k_size_row_ref) {
                 LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il);
                 return false;
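For a sense of scale (all dimensions assumed here for illustration, not taken from the commit): with DeepSeek-style values of kv_lora_rank = 512 and n_embd_head_qk_rope = 64, an MLA row holds 512 + 64 = 576 elements per cell, while a full key row at, say, 128 heads x 192 dims holds 24576, so the stored row sizes differ by more than 40x and the restore-time check above catches a mismatch immediately. A small self-contained sketch of that arithmetic:

    // Illustrative only: dimensions assumed, not read from a model.
    #include <cstdio>
    #include "ggml.h"

    int main() {
        const uint32_t kv_lora_rank        = 512;   // hparams.n_lora_kv (assumed)
        const uint32_t n_embd_head_qk_rope = 64;    // hparams.n_rot (assumed)
        const uint32_t n_embd_k_gqa        = 24576; // e.g. 128 heads * 192 dims (assumed)

        // Per-cell K row sizes in bytes for an f16 cache:
        printf("MLA row:     %zu bytes\n", ggml_row_size(GGML_TYPE_F16, kv_lora_rank + n_embd_head_qk_rope)); // 1152
        printf("non-MLA row: %zu bytes\n", ggml_row_size(GGML_TYPE_F16, n_embd_k_gqa));                       // 49152
        return 0;
    }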
