@@ -1862,9 +1862,6 @@ static bool llama_kv_cache_init(
     if (model.arch == LLM_ARCH_MAMBA) {
         // only one slot is needed for Mamba
         n_ctx = 1;
-        // it's probably best to keep as much precision as possible for the states
-        ktype = GGML_TYPE_F32;
-        vtype = GGML_TYPE_F32;
     }
 
     cache.has_shift = false;
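
A note on what this hunk leans on: with n_ctx forced to 1, each layer's K and V buffers hold a single recurrent state instead of a per-token history (the reshapes further down pull conv_state out of kv_self.k_l[il] and ssm_state out of kv_self.v_l[il]). The standalone C++ sketch below only illustrates the per-layer element counts involved; the dimension values are hypothetical (typical of a ~130M Mamba model), not taken from llama.cpp.

    // Illustrative per-layer state sizes when the KV cache is reused for Mamba states.
    // Assumed dimensions; a real model reads these from its GGUF header.
    #include <cstdint>
    #include <cstdio>

    int main() {
        const int64_t d_model = 768;          // n_embd
        const int64_t d_inner = 2 * d_model;  // Mamba expansion factor of 2
        const int64_t d_conv  = 4;            // depthwise conv width
        const int64_t d_state = 16;           // SSM state size per inner channel

        const int64_t conv_elems = (d_conv - 1) * d_inner; // what k_l[il] is reshaped into
        const int64_t ssm_elems  = d_state * d_inner;      // what v_l[il] is reshaped into

        std::printf("conv_state: %lld floats, ssm_state: %lld floats per layer\n",
                    (long long) conv_elems, (long long) ssm_elems);
        return 0;
    }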
@@ -4179,7 +4176,7 @@ static bool llm_load_tensors(
                 } break;
             case LLM_ARCH_MAMBA:
                 {
-                    const int64_t d_conv = hparams.n_embd_head_k;
+                    const int64_t d_conv = hparams.n_embd_head_k + 1;
                     const int64_t d_state = hparams.n_embd_head_v;
                     const int64_t d_inner = hparams.n_head;
                     // FIXME: ceiling instead of floor
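
Here the head-size hparams are repurposed to carry Mamba dimensions, and n_embd_head_k now holds d_conv - 1 (matching the trimmed conv_state below), hence the + 1. The FIXME presumably refers to dt_rank, which the reference Mamba implementation derives as ceil(d_model / 16) while integer division floors. A small standalone sketch of that bookkeeping, with hypothetical values (d_model = 768); not llama.cpp code:

    // Hypothetical dimension bookkeeping; values are illustrative only.
    #include <cassert>
    #include <cstdint>

    int main() {
        const int64_t d_model       = 768;
        const int64_t n_embd_head_k = 3;                 // stores d_conv - 1 in this scheme
        const int64_t d_conv        = n_embd_head_k + 1; // actual conv width
        const int64_t d_inner       = 2 * d_model;

        const int64_t dt_rank_floor = d_model / 16;            // what the code computes today
        const int64_t dt_rank_ceil  = (d_model + 16 - 1) / 16; // what the FIXME asks for

        assert(d_conv == 4 && d_inner == 1536);
        // the two only differ when d_model is not a multiple of 16
        assert(dt_rank_floor == 48 && dt_rank_ceil == 48);
        return 0;
    }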
@@ -6917,28 +6914,27 @@ struct llm_build_context {
     struct ggml_cgraph * build_mamba() {
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
 
-        const bool use_conv = batch.n_tokens > 1;
-        GGML_ASSERT(use_conv == false); // TODO: implement
+        const int32_t n_tok = batch.n_tokens;
 
         // hopefully the compiler does constant folding
         const int64_t d_model = n_embd;
         const int64_t d_inner = n_head;
         GGML_ASSERT(2 * d_model == d_inner);
-        const int64_t d_conv = n_embd_head_k;
+        const int64_t d_conv = n_embd_head_k + 1;
         const int64_t d_state = n_embd_head_v;
         const int64_t dt_rank = d_model / 16;
 
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        // NOTE: not sure what's the difference between the sequence length and the batch size in the paper.
-        // {n_embd, batch}
+        // {n_embd, n_tok}
         inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
         cb(inpL, "inp_embd", -1);
 
         for (int il = 0; il < n_layer; ++il) {
             // (ab)using the kv cache to store the state
-            ggml_tensor * conv_state = ggml_reshape_2d(ctx0, kv_self.k_l[il], d_conv, d_inner);
+            // NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed
+            ggml_tensor * conv_state = ggml_reshape_2d(ctx0, kv_self.k_l[il], d_conv - 1, d_inner);
             ggml_tensor * ssm_state = ggml_reshape_2d(ctx0, kv_self.v_l[il], d_state, d_inner);
 
             // norm
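
The reshape above is the reason for the d_conv - 1 width: the oldest column of the conv state is always shifted out before it could be used again, so it is never stored. Outside of ggml, the per-channel bookkeeping amounts to a rolling buffer; a minimal standalone C++ sketch (hypothetical helper, not part of the graph code):

    // Rolling conv state for a single inner channel: keep only the last
    // (d_conv - 1) inputs; the value that would be shifted out is dropped.
    #include <vector>

    void push_input(std::vector<float> & conv_state, float x_t) {
        // conv_state.size() == d_conv - 1
        for (size_t i = 0; i + 1 < conv_state.size(); ++i) {
            conv_state[i] = conv_state[i + 1]; // shift left by one position
        }
        conv_state.back() = x_t;               // newest input goes last
    }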
@@ -6947,33 +6943,43 @@ struct llm_build_context {
                     LLM_NORM_RMS, cb, il);
             cb(cur, "attn_norm", il);
 
-            // {n_embd, 2*d_inner} * {n_embd, batch} = {2*d_inner, batch}
+            // {n_embd, 2*d_inner} * {n_embd, n_tok} => {2*d_inner, n_tok}
             struct ggml_tensor * xz = ggml_mul_mat(ctx0, model.layers[il].ssm_in, cur);
             // split the above in two
-            // assuming it's contiguous
-            // {d_inner, batch}
+            // => {d_inner, n_tok}
             struct ggml_tensor * x = ggml_view_2d(ctx0, xz, d_inner, xz->ne[1], xz->nb[1], 0);
             struct ggml_tensor * z = ggml_view_2d(ctx0, xz, d_inner, xz->ne[1], xz->nb[1], ggml_element_size(xz)*d_inner);
 
-            cur = x;
-
             // conv
             {
-                // shift conv state left
-                conv_state = ggml_set_2d(ctx0, conv_state, ggml_view_2d(ctx0, conv_state, (d_conv - 1), d_inner, conv_state->nb[1], ggml_element_size(conv_state)*1), conv_state->nb[1], 0);
-
-                // update last column
-                // x here is {d_inner, 1} (a row), but should be {1, d_inner} (a column)
-                conv_state = ggml_set_2d(ctx0, conv_state, ggml_cont(ctx0, ggml_transpose(ctx0, x)), conv_state->nb[1], ggml_element_size(conv_state)*(d_conv - 1));
-
-                ggml_build_forward_expand(gf, ggml_cpy(ctx0, conv_state, ggml_view_tensor(ctx0, kv_self.k_l[il])));
-
-                // rearrange and sum
-                // no need to rearrange the conv_state, since it's already in the right shape
-                // => {1, d_inner}
-                x = ggml_sum_rows(ctx0, ggml_mul(ctx0, conv_state, model.layers[il].ssm_conv1d));
-                // => {d_inner, 1}
-                x = ggml_transpose(ctx0, x);
+                // concat last (d_conv - 1) columns of conv_state, and x
+
+                // The following tensor is bigger than necessary to avoid an assertion error when making an overlapping view.
+                // TODO: in ggml_new_tensor_impl, handle overlapping data range in data size calculation
+                // This could then be a tensor with ne[] = {(d_conv-1)+n_tok, d_inner},
+                // which is around (d_conv-1) times smaller than its current size.
+                struct ggml_tensor * conv_x = ggml_new_tensor_1d(ctx0, conv_state->type, d_conv*d_inner*n_tok);
+                const size_t conv_x_nb1 = (d_conv - 1 + n_tok) * ggml_element_size(conv_x);
+
+                conv_x = ggml_set_2d(ctx0, conv_x, conv_state, conv_x_nb1, 0);
+                // unfortunately, making x contiguous is necessary because ggml_set expects nb0 == sizeof(float)
+                conv_x = ggml_set_2d(ctx0, conv_x, ggml_cont(ctx0, ggml_transpose(ctx0, x)), conv_x_nb1, (d_conv - 1)*ggml_element_size(conv_x));
+
+                // store last (d_conv - 1) columns of conv_x back into the KV cache for the next conv_state
+                ggml_build_forward_expand(gf,
+                    ggml_cpy(ctx0,
+                        ggml_view_2d(ctx0, conv_x, d_conv - 1, d_inner, conv_x_nb1, n_tok*ggml_element_size(conv_x)),
+                        ggml_view_tensor(ctx0, kv_self.k_l[il])));
+
+                // prepare convolution for all tokens in the batch with a self-overlapping view
+                // {(d_conv-1)+n_tok, d_inner} => {d_conv, d_inner, n_tok}
+                conv_x = ggml_view_3d(ctx0, conv_x, d_conv, d_inner, n_tok, conv_x_nb1, -(d_conv - 1)*d_inner*ggml_element_size(conv_x), 0);
+
+                // perform convolution
+                // => {1, d_inner, n_tok}
+                x = ggml_sum_rows(ctx0, ggml_mul(ctx0, conv_x, model.layers[il].ssm_conv1d));
+                // => {d_inner, n_tok, 1}
+                x = ggml_permute(ctx0, x, 2, 0, 1, 3);
 
                 // bias
                 x = ggml_add(ctx0, x, model.layers[il].ssm_conv1d_b);
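
For each inner channel, the concat-then-view construction above is meant to compute a causal depthwise convolution over [previous (d_conv - 1) inputs | n_tok new inputs]: output t is the dot product of the d_conv weights with positions t .. t + d_conv - 1 of that concatenation, and the last d_conv - 1 positions become the next conv_state. A standalone reference of the same arithmetic for one channel (plain C++, illustrative names, no ggml); the bias add and SiLU in the diff are applied to these outputs afterwards.

    // Reference arithmetic for the batched causal depthwise conv of one channel.
    // conv_state: previous (d_conv - 1) inputs, x: n_tok new inputs,
    // w: d_conv conv weights for this channel. Returns n_tok outputs and
    // leaves conv_state holding the last (d_conv - 1) inputs for the next batch.
    #include <vector>

    std::vector<float> conv_channel(std::vector<float> & conv_state,
                                    const std::vector<float> & x,
                                    const std::vector<float> & w) {
        const size_t d_conv = w.size();
        const size_t n_tok  = x.size();

        // conv_x = [conv_state | x], length (d_conv - 1) + n_tok
        std::vector<float> conv_x(conv_state);
        conv_x.insert(conv_x.end(), x.begin(), x.end());

        // each output is a dot product over a window that slides by one element,
        // which is what the self-overlapping view is intended to express in the graph
        std::vector<float> y(n_tok, 0.0f);
        for (size_t t = 0; t < n_tok; ++t) {
            for (size_t i = 0; i < d_conv; ++i) {
                y[t] += w[i] * conv_x[t + i];
            }
        }

        // the last (d_conv - 1) columns become the next conv_state
        conv_state.assign(conv_x.end() - (d_conv - 1), conv_x.end());
        return y;
    }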
@@ -6983,23 +6989,24 @@ struct llm_build_context {
 
             // ssm
             {
-                // {2*n_embd, batch} * {2*n_embd, dt_rank + 2*d_state} = {batch, dt_rank + 2*d_state}
-                struct ggml_tensor * x_db = ggml_mul_mat(ctx0, x, model.layers[il].ssm_x);
-                // FIXME: handle batches of more than 1 token
-                struct ggml_tensor * dt = ggml_view_1d(ctx0, x_db, dt_rank, 0);
-                struct ggml_tensor * B = ggml_view_1d(ctx0, x_db, d_state, ggml_element_size(x_db)*dt_rank);
-                struct ggml_tensor * C = ggml_view_1d(ctx0, x_db, d_state, ggml_element_size(x_db)*(dt_rank+d_state));
-
-                // {dt_rank} * {dt_rank, d_inner} = {1, d_inner}
-                dt = ggml_mul_mat(ctx0, dt, model.layers[il].ssm_dt);
-                dt = ggml_add(ctx0, dt, ggml_transpose(ctx0, model.layers[il].ssm_dt_b));
+                // {d_inner, dt_rank + 2*d_state} * {d_inner, n_tok} => {dt_rank + 2*d_state, n_tok}
+                struct ggml_tensor * x_db = ggml_mul_mat(ctx0, model.layers[il].ssm_x, x);
+                // split
+                struct ggml_tensor * dt = ggml_view_2d(ctx0, x_db, dt_rank, x_db->ne[1], x_db->nb[1], 0);
+                struct ggml_tensor * B = ggml_view_2d(ctx0, x_db, d_state, x_db->ne[1], x_db->nb[1], ggml_element_size(x_db)*dt_rank);
+                struct ggml_tensor * C = ggml_view_2d(ctx0, x_db, d_state, x_db->ne[1], x_db->nb[1], ggml_element_size(x_db)*(dt_rank+d_state));
+
+                // {dt_rank, d_inner} * {dt_rank, n_tok} => {d_inner, n_tok}
+                dt = ggml_mul_mat(ctx0, model.layers[il].ssm_dt, dt);
+                dt = ggml_add(ctx0, dt, model.layers[il].ssm_dt_b);
                 dt = ggml_soft_plus(ctx0, dt);
 
+                // FIXME: support batches with more than 1 token
                 // => {d_state, d_inner}
-                struct ggml_tensor * dA = ggml_exp(ctx0, ggml_mul(ctx0, model.layers[il].ssm_a, dt));
+                struct ggml_tensor * dA = ggml_exp(ctx0, ggml_mul(ctx0, model.layers[il].ssm_a, ggml_transpose(ctx0, dt)));
 
                 // => {d_state, d_inner}
-                struct ggml_tensor * dB = ggml_out_prod(ctx0, B, ggml_transpose(ctx0, dt));
+                struct ggml_tensor * dB = ggml_out_prod(ctx0, B, dt);
 
                 // => {d_state, d_inner}
                 cur = ggml_mul(ctx0, dB, ggml_transpose(ctx0, x));
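
This block covers the discretization half of the selective scan: dt is projected, biased, and passed through softplus, then dA = exp(dt * A) is formed elementwise, dB is the outer product of dt and B, and dB times x is the input contribution; the state update and the C projection happen in the lines between this hunk and the next. Strung together for a single token and a single inner channel, the whole step looks roughly like the standalone sketch below (plain C++, illustrative function name, no ggml; in Mamba, A is negative, so exp(dt * A) acts as a decay).

    // Single-token selective-scan update for one inner channel, mirroring the
    // per-element arithmetic of the ssm block (softplus, exp, outer product).
    #include <cmath>
    #include <vector>

    float ssm_step(std::vector<float> & ssm_state,    // d_state values for this channel
                   const std::vector<float> & A,      // d_state values, negative in Mamba
                   const std::vector<float> & B,      // d_state values, input-dependent
                   const std::vector<float> & C,      // d_state values, input-dependent
                   float dt_raw,                      // projected time step before softplus
                   float x) {                         // post-conv, post-activation input
        const float dt = std::log1p(std::exp(dt_raw)); // softplus, as in ggml_soft_plus
        float y = 0.0f;
        for (size_t i = 0; i < ssm_state.size(); ++i) {
            const float dA = std::exp(dt * A[i]);      // discretized state transition
            const float dB = dt * B[i];                // discretized input projection
            ssm_state[i] = ssm_state[i] * dA + dB * x; // recurrence
            y += C[i] * ssm_state[i];                  // output projection
        }
        return y; // the graph then adds D*x and gates with silu(z)
    }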
@@ -7014,7 +7021,7 @@ struct llm_build_context {
                 y = ggml_add(ctx0, y, ggml_mul(ctx0, model.layers[il].ssm_d, x));
                 y = ggml_mul(ctx0, y, ggml_silu(ctx0, z));
 
-                // {d_inner, n_embd} * {d_inner, 1} = {n_embd, 1}
+                // {d_inner, n_embd} * {d_inner, 1} => {n_embd, 1}
                 cur = ggml_mul_mat(ctx0, model.layers[il].ssm_out, y);
             }
 
@@ -10722,8 +10729,15 @@ struct llama_context * llama_new_context_with_model(
     ctx->rng        = std::mt19937(params.seed);
     ctx->logits_all = params.logits_all;
 
-    const ggml_type type_k = params.type_k;
-    const ggml_type type_v = params.type_v;
+    ggml_type type_k = params.type_k;
+    ggml_type type_v = params.type_v;
+
+    // Mamba (mis)uses the KV cache to store its states
+    if (model->arch == LLM_ARCH_MAMBA) {
+        // it's probably best to keep as much precision as possible for the states
+        type_k = GGML_TYPE_F32; // required by ggml_set for Mamba's conv_state
+        type_v = GGML_TYPE_F32; // required by ggml_mul for Mamba's ssm_state
+    }
 
     GGML_ASSERT(hparams.n_embd_head_k % ggml_blck_size(type_k) == 0);
     GGML_ASSERT(hparams.n_embd_head_v % ggml_blck_size(type_v) == 0);
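
Forcing F32 here costs very little memory, since the recurrent states are tiny compared with a token-indexed KV cache. A rough, standalone estimate with hypothetical dimensions around the ~130M scale:

    // Back-of-the-envelope size of the F32 Mamba states across all layers.
    // Values are illustrative, not read from a real model file.
    #include <cstdint>
    #include <cstdio>

    int main() {
        const int64_t n_layer = 24;
        const int64_t d_inner = 1536;
        const int64_t d_conv  = 4;
        const int64_t d_state = 16;

        const int64_t floats_per_layer = (d_conv - 1 + d_state) * d_inner; // conv_state + ssm_state
        const int64_t bytes_total      = n_layer * floats_per_layer * 4;   // F32

        std::printf("total state size: %.2f MiB\n", bytes_total / (1024.0 * 1024.0));
        return 0;
    }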