@@ -1579,7 +1579,6 @@ struct llama_layer {
     struct ggml_tensor * ffn_up_b;   // b3
     struct ggml_tensor * ffn_act;

-
     // mamba proj
     struct ggml_tensor * ssm_in;
     struct ggml_tensor * ssm_x;
@@ -3107,6 +3106,7 @@ static void llm_load_hparams(
             } break;
         case LLM_ARCH_MAMBA:
             {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
                 switch (hparams.n_layer) {
                     case 24:
                         switch (hparams.n_embd) {
@@ -3127,7 +3127,7 @@ static void llm_load_hparams(
                         } break;
                     default: model.type = e_model::MODEL_UNKNOWN;
                 }
-            }
+            } break;
         default: (void)0;
     }

@@ -3591,7 +3591,10 @@ static bool llm_load_tensors(
     const int64_t n_vocab = hparams.n_vocab;
     const int64_t n_ff    = hparams.n_ff;

-    GGML_ASSERT(n_embd_gqa == n_embd_k_gqa);
+    // Mamba uses these in its own way
+    if (model.arch != LLM_ARCH_MAMBA) {
+        GGML_ASSERT(n_embd_gqa == n_embd_k_gqa);
+    }

     ggml_context * ctx_input  = ctx_map.at(model.buft_input.buft);
     ggml_context * ctx_output = ctx_map.at(model.buft_output.buft);
@@ -4176,19 +4179,21 @@ static bool llm_load_tensors(
             } break;
         case LLM_ARCH_MAMBA:
             {
-                model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
-
                 const int64_t d_conv  = hparams.n_embd_head_k;
                 const int64_t d_state = hparams.n_embd_head_v;
                 const int64_t d_inner = hparams.n_head;
                 // FIXME: ceiling instead of floor
                 const int64_t dt_rank = n_embd / 16;
                 GGML_ASSERT(2 * n_embd == d_inner);
+                // round up the vocab size to the next multiple of 8
+                const int64_t rounded_vocab = (n_vocab + 7) & -8;
+
+                model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, rounded_vocab});

                 // output
                 {
                     model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
-                    model.output      = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab});
+                    model.output      = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, rounded_vocab});
                 }

                 for (int i = 0; i < n_layer; ++i) {
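
Note on the rounding above: `(n_vocab + 7) & -8` is the usual round-up-to-a-multiple-of-8 bit trick (in two's complement, -8 == ~7, so adding 7 and masking clears the low three bits). A minimal sketch of the same computation, not part of the patch:

#include <stdint.h>

// illustrative helper, equivalent to the rounded_vocab expression above
static inline int64_t round_up_to_8(int64_t n) {
    return (n + 7) & -8;   // -8 == ~7: clears the low 3 bits of (n + 7)
}

// e.g. round_up_to_8(50277) == 50280 and round_up_to_8(50280) == 50280;
// 50277 is the GPT-NeoX tokenizer size used by the Mamba releases,
// whose embeddings are padded to 50280.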
@@ -4205,17 +4210,17 @@ static bool llm_load_tensors(

                     layer.ssm_in = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, 2*d_inner});

-                    layer.ssm_conv1d = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, 1, d_inner});
+                    layer.ssm_conv1d = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner});
                     layer.ssm_conv1d_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner});

                     layer.ssm_x = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_X, "weight", i), {d_inner, dt_rank + 2*d_state});

                     layer.ssm_dt = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_DT, "weight", i), {dt_rank, d_inner});
                     layer.ssm_dt_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_DT, "bias", i), {d_inner});

-                    // FIXME: maybe no suffix for these
-                    layer.ssm_a = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_A, "weight", i), {d_state, d_inner});
-                    layer.ssm_d = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_D, "weight", i), {d_inner});
+                    // no "weight" suffix for these
+                    layer.ssm_a = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_A, i), {d_state, d_inner});
+                    layer.ssm_d = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_D, i), {d_inner});

                     // out_proj
                     layer.ssm_out = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd});
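
On the two-argument `tn(LLM_TENSOR_SSM_A, i)` / `tn(LLM_TENSOR_SSM_D, i)` calls above: the A and D tensors are looked up without a ".weight" suffix. The name table itself is outside this hunk, so the exact prefix is an assumption, but assuming the usual "blk.<i>.<name>[.suffix]" scheme the expected GGUF names would be built roughly like this hypothetical sketch:

#include <stdio.h>

int main(void) {
    char name[64];
    // with a suffix, as for most per-layer tensors:
    snprintf(name, sizeof(name), "blk.%d.%s.%s", 0, "ssm_dt", "weight");
    printf("%s\n", name);   // blk.0.ssm_dt.weight
    // without a suffix, as for ssm_a / ssm_d in this patch:
    snprintf(name, sizeof(name), "blk.%d.%s", 0, "ssm_a");
    printf("%s\n", name);   // blk.0.ssm_a
    return 0;
}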
@@ -6909,16 +6914,18 @@ struct llm_build_context {
         return gf;
     }

-    struct ggml_cgraph * build_mamba(bool use_conv) {
+    struct ggml_cgraph * build_mamba() {
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);

+        const bool use_conv = batch.n_tokens > 1;
         GGML_ASSERT(use_conv == false); // TODO: implement

-        const int64_t d_model = hparams.n_embd;
-        const int64_t d_inner = hparams.n_head;
+        // hopefully the compiler does constant folding
+        const int64_t d_model = n_embd;
+        const int64_t d_inner = n_head;
         GGML_ASSERT(2 * d_model == d_inner);
-        const int64_t d_conv  = hparams.n_embd_head_k;
-        const int64_t d_state = hparams.n_embd_head_v;
+        const int64_t d_conv  = n_embd_head_k;
+        const int64_t d_state = n_embd_head_v;
         const int64_t dt_rank = d_model / 16;

         struct ggml_tensor * cur;
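
Since there are no dedicated GGUF keys for the Mamba dimensions yet, the locals above simply rename repurposed attention hparams. A worked example, assuming a 768-wide model with the upstream Mamba defaults (d_conv = 4, d_state = 16); the concrete numbers are illustrative, not read from a model file:

// assumed mapping and example values for n_embd = 768:
//   d_model = n_embd        = 768
//   d_inner = n_head        = 1536   // 2 * d_model, enforced by the assert above
//   d_conv  = n_embd_head_k = 4      // conv kernel width
//   d_state = n_embd_head_v = 16     // SSM state size per channel
//   dt_rank = d_model / 16  = 48     // upstream uses ceil(d_model/16), hence the
//                                    // FIXME about flooring in llm_load_tensors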
@@ -6930,8 +6937,10 @@ struct llm_build_context {

         for (int il = 0; il < n_layer; ++il) {
             // (ab)using the kv cache to store the state
-            ggml_tensor * conv_state = kv_self.k_l[il]; // {d_conv, d_inner}
-            ggml_tensor * ssm_state  = kv_self.v_l[il]; // {d_state, d_inner}
+            // NOTE: the conv_state is transposed to ease shifting it.
+            // if you figured out a way to shift it without transposing it like this, go ahead and fix this.
+            ggml_tensor * conv_state = kv_self.k_l[il]; // {d_inner, d_conv}
+            ggml_tensor * ssm_state  = ggml_reshape_2d(ctx0, kv_self.v_l[il], d_state, d_inner);

             // norm
             cur = llm_build_norm(ctx0, inpL, hparams,
@@ -6943,36 +6952,73 @@ struct llm_build_context {
             // {n_embd, batch} * {n_embd, 2*d_inner} = {batch, 2*d_inner}
             struct ggml_tensor * xz = ggml_mul_mat(ctx0, cur, model.layers[il].ssm_in);
             // split the above in two
+            // assuming it's contiguous
+            // FIXME: handle batches of more than 1 token
             struct ggml_tensor * x = ggml_view_1d(ctx0, xz, d_inner, 0);
-            struct ggml_tensor * z = ggml_view_1d(ctx0, xz, d_inner, d_inner);
+            struct ggml_tensor * z = ggml_view_1d(ctx0, xz, d_inner, ggml_element_size(xz)*d_inner);
+
+            cur = x;

-            // FIXME: figure out when to transpose
             // conv
             {
-                // TODO: figure out how to do a row-wise dot product
-                // TODO: use the kv-cache to store the state
-                kv_self.k_l[il];
+                // shift conv state left
+                conv_state = ggml_set_1d(ctx0, conv_state, ggml_view_1d(ctx0, conv_state, (d_conv - 1)*d_inner, ggml_element_size(conv_state)*d_inner), 0);
+
+                // update last column
+                conv_state = ggml_set_1d(ctx0, conv_state, x, ggml_element_size(conv_state)*(d_conv - 1)*d_inner);
+
+                ggml_build_forward_expand(gf, ggml_cpy(ctx0, conv_state, ggml_view_tensor(ctx0, kv_self.k_l[il])));

-                // FIXME: this is wrong
-                cur = ggml_conv_1d(ctx0, cur, model.layers[il].ssm_conv1d, 1, d_conv - 1, 1);
+                // rearrange and sum
+                conv_state = ggml_reshape_2d(ctx0, conv_state, d_inner, d_conv);
+                // TODO: find a way to directly shift a 2d conv_state, avoiding the need to transpose here.
+                conv_state = ggml_cont(ctx0, ggml_transpose(ctx0, conv_state));

-                cur = ggml_add(ctx0, cur, model.layers[il].ssm_conv1d_b);
+                // --> {1, d_inner}
+                x = ggml_sum_rows(ctx0, ggml_mul(ctx0, conv_state, model.layers[il].ssm_conv1d));
+                x = ggml_transpose(ctx0, x);

-                // TODO: there's some SiLU in there (but no ffn? or is the conv an ffn?)
-                cur = ggml_silu(ctx0, cur);
+                // bias
+                x = ggml_add(ctx0, x, model.layers[il].ssm_conv1d_b);
+
+                x = ggml_silu(ctx0, x);
             }

             // ssm
             {
+                // {2*n_embd, batch} * {2*n_embd, dt_rank + 2*d_state} = {batch, dt_rank + 2*d_state}
+                struct ggml_tensor * x_db = ggml_mul_mat(ctx0, x, model.layers[il].ssm_x);
+                // FIXME: handle batches of more than 1 token
+                struct ggml_tensor * dt = ggml_view_1d(ctx0, x_db, dt_rank, 0);
+                struct ggml_tensor * B  = ggml_view_1d(ctx0, x_db, d_state, ggml_element_size(x_db)*dt_rank);
+                struct ggml_tensor * C  = ggml_view_1d(ctx0, x_db, d_state, ggml_element_size(x_db)*(dt_rank+d_state));

-                // TODO: use ggml_soft_plus here
+                // {dt_rank} * {dt_rank, d_inner} = {1, d_inner}
+                dt = ggml_mul_mat(ctx0, dt, model.layers[il].ssm_dt);
+                dt = ggml_add(ctx0, dt, ggml_transpose(ctx0, model.layers[il].ssm_dt_b));
+                dt = ggml_soft_plus(ctx0, dt);

-            }
+                // => {d_state, d_inner}
+                struct ggml_tensor * dA = ggml_exp(ctx0, ggml_mul(ctx0, model.layers[il].ssm_a, dt));

-            // TODO: there's some SiLU again towards the end. Can the `llm_build_ffn` helper be used?
-            // Maybe the best way is to implement it, _then_ check if that helper would do the same thing.
-            // discretize
-            {
+                // => {d_state, d_inner}
+                struct ggml_tensor * dB = ggml_out_prod(ctx0, B, ggml_transpose(ctx0, dt));
+
+                // => {d_state, d_inner}
+                cur = ggml_mul(ctx0, dB, ggml_transpose(ctx0, x));
+
+                ssm_state = ggml_add(ctx0, ggml_mul(ctx0, ssm_state, dA), cur);
+
+                ggml_build_forward_expand(gf, ggml_cpy(ctx0, ssm_state, ggml_view_tensor(ctx0, kv_self.v_l[il])));
+
+                // row-wise dot product ("dn,n->d")
+                // {d_state, d_inner} * {d_state} => {d_inner, 1}
+                struct ggml_tensor * y = ggml_mul_mat(ctx0, ssm_state, C);
+                y = ggml_add(ctx0, y, ggml_mul(ctx0, model.layers[il].ssm_d, x));
+                y = ggml_mul(ctx0, y, ggml_silu(ctx0, z));
+
+                // {d_inner, n_embd} * {d_inner, 1} = {n_embd, 1}
+                cur = ggml_mul_mat(ctx0, model.layers[il].ssm_out, y);
             }

             // residual
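
To make the graph above easier to audit, here is a minimal scalar sketch of what one token's worth of the conv and ssm blocks computes, assuming a single-token batch and that `dt` has already been through the dt projection, bias and softplus; the names and the row-major [d_inner][...] layout are illustrative, not the ggml tensors themselves:

#include <math.h>

// silu(v) = v * sigmoid(v)
static inline float silu_ref(float v) { return v / (1.0f + expf(-v)); }

// one recurrent Mamba step for a single token (reference only)
static void mamba_step_ref(
        int d_inner, int d_conv, int d_state,
        const float * x_in,       // [d_inner]           first half of the ssm_in projection
        const float * z,          // [d_inner]           gate half of the same projection
        float       * conv_state, // [d_inner][d_conv]   rolling window of past x values
        float       * ssm_state,  // [d_inner][d_state]  recurrent SSM state
        const float * conv_w,     // [d_inner][d_conv]   ssm_conv1d
        const float * conv_b,     // [d_inner]           ssm_conv1d_b
        const float * dt,         // [d_inner]           softplus(ssm_dt(x) + ssm_dt_b)
        const float * A,          // [d_inner][d_state]  ssm_a
        const float * B,          // [d_state]           from the ssm_x projection
        const float * C,          // [d_state]           from the ssm_x projection
        const float * D,          // [d_inner]           ssm_d
        float       * y) {        // [d_inner]           input to the ssm_out projection
    for (int i = 0; i < d_inner; ++i) {
        // depthwise conv: shift the per-channel window left and append the new value
        for (int t = 0; t < d_conv - 1; ++t) {
            conv_state[i*d_conv + t] = conv_state[i*d_conv + t + 1];
        }
        conv_state[i*d_conv + d_conv - 1] = x_in[i];

        float xc = conv_b[i];
        for (int t = 0; t < d_conv; ++t) {
            xc += conv_state[i*d_conv + t] * conv_w[i*d_conv + t];
        }
        xc = silu_ref(xc);

        // selective SSM: h = exp(dt*A)*h + (dt*B)*x ;  y = C.h + D*x
        float acc = 0.0f;
        for (int s = 0; s < d_state; ++s) {
            float h = expf(dt[i] * A[i*d_state + s]) * ssm_state[i*d_state + s]
                    + dt[i] * B[s] * xc;
            ssm_state[i*d_state + s] = h;
            acc += h * C[s];
        }

        // skip connection and silu(z) gate, as in the graph above
        y[i] = (acc + D[i] * xc) * silu_ref(z[i]);
    }
}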
@@ -6983,11 +7029,8 @@ struct llm_build_context {
             inpL = cur;
         }

-        // the last step of each layer already makes these equivalent
-        // cur = inpL;
-
         // final rmsnorm
-        cur = llm_build_norm(ctx0, cur, hparams,
+        cur = llm_build_norm(ctx0, inpL, hparams,
                 model.output_norm, NULL,
                 LLM_NORM_RMS, cb, -1);
         cb(cur, "result_norm", -1);
@@ -7165,7 +7208,7 @@ static struct ggml_cgraph * llama_build_graph(
             } break;
         case LLM_ARCH_MAMBA:
             {
-                result = llm.build_mamba(/* use_conv = */ batch.n_tokens > 1);
+                result = llm.build_mamba();
             } break;
         default:
             GGML_ASSERT(false);