@@ -574,6 +574,9 @@ static std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES =
             { LLM_TENSOR_OUTPUT,      "output" },
             { LLM_TENSOR_ATTN_NORM,   "blk.%d.attn_norm" },
             { LLM_TENSOR_ATTN_QKV,    "blk.%d.attn_qkv" },
+            { LLM_TENSOR_ATTN_Q,      "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K,      "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V,      "blk.%d.attn_v" },
             { LLM_TENSOR_ATTN_OUT,    "blk.%d.attn_output" },
             { LLM_TENSOR_FFN_DOWN,    "blk.%d.ffn_down" },
             { LLM_TENSOR_FFN_UP,      "blk.%d.ffn_up" },
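Note: these mapping strings are consumed by the tn() name helper, which substitutes the layer index and appends a ".weight"/".bias" suffix to form the tensor name looked up in the GGUF file. A minimal sketch of that formatting step (tensor_name below is a hypothetical stand-in, not llama.cpp's actual LLM_TN helper):

    #include <cstdio>
    #include <string>

    // Hypothetical stand-in for the tn() helper: expand the per-layer pattern
    // and append the suffix, e.g. "blk.%d.attn_q" + "weight" for layer 3
    // yields "blk.3.attn_q.weight".
    static std::string tensor_name(const char * pattern, const char * suffix, int layer) {
        char buf[256];
        std::snprintf(buf, sizeof(buf), pattern, layer);
        return std::string(buf) + "." + suffix;
    }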
@@ -3676,8 +3679,19 @@ static bool llm_load_tensors(
                     layer.attn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
                     layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i),   {n_embd});
 
-                    layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa});
-                    layer.bqkv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i),   {n_embd + 2*n_embd_gqa});
+                    layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, false);
+                    layer.bqkv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i),   {n_embd + 2*n_embd_gqa},         false);
+
+                    if (layer.wqkv == nullptr) {
+                        layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd});
+                        layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i),   {n_embd});
+
+                        layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa});
+                        layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i),   {n_embd_gqa});
+
+                        layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa});
+                        layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i),   {n_embd_gqa});
+                    }
 
                     layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
                     layer.bo = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i),   {n_embd});
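The trailing false argument appears to mark the fused QKV tensors as optional: when the model file stores separate Q/K/V projections instead, create_tensor presumably returns nullptr rather than raising an error, and the if (layer.wqkv == nullptr) branch loads the split tensors. Note the GQA-aware shapes: wq is {n_embd, n_embd} while wk/wv are {n_embd, n_embd_gqa}, since K/V may have fewer heads than Q. A minimal sketch of the optional-lookup pattern under that assumption (tensor_table and the lookup are hypothetical, not llama.cpp's actual loader internals):

    #include <map>
    #include <stdexcept>
    #include <string>

    struct ggml_tensor; // opaque handle, as in ggml

    // Hypothetical table of tensors already read from the model file.
    using tensor_table = std::map<std::string, ggml_tensor *>;

    // With required == false, a missing name yields nullptr instead of an
    // error, so the caller can fall back to an alternative tensor layout.
    ggml_tensor * create_tensor_sketch(const tensor_table & tab, const std::string & name, bool required = true) {
        auto it = tab.find(name);
        if (it == tab.end()) {
            if (required) {
                throw std::runtime_error("missing required tensor: " + name);
            }
            return nullptr;
        }
        return it->second;
    }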
@@ -5637,15 +5651,25 @@ struct llm_build_context {
 
             // self-attention
             {
-                cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, attn_norm_output);
-                cb(cur, "wqkv", il);
+                struct ggml_tensor * Qcur = nullptr;
+                struct ggml_tensor * Kcur = nullptr;
+                struct ggml_tensor * Vcur = nullptr;
 
-                cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
-                cb(cur, "bqkv", il);
+                if (model.layers[il].wqkv) {
+                    cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, attn_norm_output);
+                    cb(cur, "wqkv", il);
 
-                struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
-                struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
-                struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
+                    cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
+                    cb(cur, "bqkv", il);
+
+                    Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
+                    Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
+                    Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
+                } else {
+                    Qcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wq, attn_norm_output), model.layers[il].bq);
+                    Kcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wk, attn_norm_output), model.layers[il].bk);
+                    Vcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wv, attn_norm_output), model.layers[il].bv);
+                }
 
                 cb(Qcur, "Qcur", il);
                 cb(Kcur, "Kcur", il);
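In the fused path, each token's row of cur holds the concatenated Q, K, V activations (n_embd + 2*n_embd_gqa floats), and the three ggml_view_2d calls slice it by byte offset with row stride cur->nb[1]: Q at offset 0, K immediately after Q, V immediately after K. The else branch produces the same three tensors via separate matmul-plus-bias steps. A small self-checking example of the offset arithmetic (sizes are illustrative; n_embd_gqa == n_embd whenever n_head_kv == n_head, i.e. no grouped-query attention):

    #include <cassert>
    #include <cstddef>

    int main() {
        const size_t n_embd     = 2560; // illustrative model width
        const size_t n_embd_gqa = 2560; // == n_embd without grouped-query attention

        // Byte offsets used by the ggml_view_2d slices above.
        const size_t q_off = 0*sizeof(float)*(n_embd);
        const size_t k_off = 1*sizeof(float)*(n_embd);
        const size_t v_off = 1*sizeof(float)*(n_embd + n_embd_gqa);

        // K starts right after the n_embd Q values; V right after the
        // n_embd_gqa K values.
        assert(k_off == q_off + sizeof(float)*n_embd);
        assert(v_off == k_off + sizeof(float)*n_embd_gqa);
        return 0;
    }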