@@ -6931,15 +6931,14 @@ struct llm_build_context {
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
+        // NOTE: not sure what the difference is between the sequence length and the batch size in the paper.
         // {n_embd, batch}
         inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
         cb(inpL, "inp_embd", -1);
 
         for (int il = 0; il < n_layer; ++il) {
             // (ab)using the kv cache to store the state
-            // NOTE: the conv_state is transposed to ease shifting it.
-            // if you figured out a way to shift it without transposing it like this, go ahead and fix this.
-            ggml_tensor * conv_state = kv_self.k_l[il]; // {d_inner, d_conv}
+            ggml_tensor * conv_state = ggml_reshape_2d(ctx0, kv_self.k_l[il], d_conv, d_inner); // {d_conv, d_inner}
             ggml_tensor * ssm_state = ggml_reshape_2d(ctx0, kv_self.v_l[il], d_state, d_inner);
 
             // norm
@@ -6948,33 +6947,32 @@ struct llm_build_context {
                 LLM_NORM_RMS, cb, il);
             cb(cur, "attn_norm", il);
 
-            // {n_embd, batch} * {n_embd, 2*d_inner} = {batch, 2*d_inner}
-            struct ggml_tensor * xz = ggml_mul_mat(ctx0, cur, model.layers[il].ssm_in);
+            // {n_embd, 2*d_inner} * {n_embd, batch} = {2*d_inner, batch}
+            struct ggml_tensor * xz = ggml_mul_mat(ctx0, model.layers[il].ssm_in, cur);
             // split the above in two
             // assuming it's contiguous
-            // FIXME: handle batches of more than 1 token
-            struct ggml_tensor * x = ggml_view_1d(ctx0, xz, d_inner, 0);
-            struct ggml_tensor * z = ggml_view_1d(ctx0, xz, d_inner, ggml_element_size(xz)*d_inner);
+            // {d_inner, batch}
+            struct ggml_tensor * x = ggml_view_2d(ctx0, xz, d_inner, xz->ne[1], xz->nb[1], 0);
+            struct ggml_tensor * z = ggml_view_2d(ctx0, xz, d_inner, xz->ne[1], xz->nb[1], ggml_element_size(xz)*d_inner);
 
             cur = x;
 
             // conv
             {
                 // shift conv state left
-                conv_state = ggml_set_1d(ctx0, conv_state, ggml_view_1d(ctx0, conv_state, (d_conv - 1)*d_inner, ggml_element_size(conv_state)*d_inner), 0);
+                conv_state = ggml_set_2d(ctx0, conv_state, ggml_view_2d(ctx0, conv_state, (d_conv - 1), d_inner, conv_state->nb[1], ggml_element_size(conv_state)*1), conv_state->nb[1], 0);
 
                 // update last column
-                conv_state = ggml_set_1d(ctx0, conv_state, x, ggml_element_size(conv_state)*(d_conv - 1)*d_inner);
+                // x here is {d_inner, 1} (a row), but should be {1, d_inner} (a column)
+                conv_state = ggml_set_2d(ctx0, conv_state, ggml_cont(ctx0, ggml_transpose(ctx0, x)), conv_state->nb[1], ggml_element_size(conv_state)*(d_conv - 1));
 
                 ggml_build_forward_expand(gf, ggml_cpy(ctx0, conv_state, ggml_view_tensor(ctx0, kv_self.k_l[il])));
 
                 // rearrange and sum
-                conv_state = ggml_reshape_2d(ctx0, conv_state, d_inner, d_conv);
-                // TODO: find a way to directly shift a 2d conv_state, avoiding the need to transpose here.
-                conv_state = ggml_cont(ctx0, ggml_transpose(ctx0, conv_state));
-
-                // --> {1, d_inner}
+                // no need to rearrange the conv_state, since it's already in the right shape
+                // => {1, d_inner}
                 x = ggml_sum_rows(ctx0, ggml_mul(ctx0, conv_state, model.layers[il].ssm_conv1d));
+                // => {d_inner, 1}
                 x = ggml_transpose(ctx0, x);
 
                 // bias
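
For reference, here is a minimal standalone sketch (not part of this commit) of the two conventions the hunks above rely on: `ggml_mul_mat` contracts over `ne[0]` of both operands, which is why swapping the arguments yields `{2*d_inner, batch}`, and `ggml_view_2d` can split `xz` row-wise without copying. Only the public ggml API is assumed, and `n_embd`/`d_inner`/`n_batch` are tiny stand-in values:

```c
// sketch: ggml_mul_mat shape convention + no-copy row-wise split,
// assuming ggml.h from this repo is on the include path
#include "ggml.h"
#include <stdio.h>

int main(void) {
    struct ggml_init_params params = { /*.mem_size =*/ 16*1024*1024, /*.mem_buffer =*/ NULL, /*.no_alloc =*/ false };
    struct ggml_context * ctx = ggml_init(params);

    // tiny stand-ins for the real hyperparameters
    const int n_embd = 4, d_inner = 3, n_batch = 2;

    // ggml_mul_mat(a, b) requires a->ne[0] == b->ne[0] and contracts over it:
    // {n_embd, 2*d_inner} * {n_embd, n_batch} => {2*d_inner, n_batch}
    struct ggml_tensor * ssm_in = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, 2*d_inner);
    struct ggml_tensor * cur    = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_batch);
    struct ggml_tensor * xz     = ggml_mul_mat(ctx, ssm_in, cur);

    printf("xz: {%d, %d}\n", (int) xz->ne[0], (int) xz->ne[1]); // xz: {6, 2}

    // split xz into x and z, each {d_inner, n_batch}: same row stride nb[1],
    // only the byte offset differs, so no data is moved
    struct ggml_tensor * x = ggml_view_2d(ctx, xz, d_inner, xz->ne[1], xz->nb[1], 0);
    struct ggml_tensor * z = ggml_view_2d(ctx, xz, d_inner, xz->ne[1], xz->nb[1], ggml_element_size(xz)*d_inner);

    printf("x: {%d, %d}, z: {%d, %d}\n",
        (int) x->ne[0], (int) x->ne[1], (int) z->ne[0], (int) z->ne[1]); // both {3, 2}

    ggml_free(ctx);
    return 0;
}
```

Because the views only record sizes, strides and a byte offset, they are only valid if `xz` is laid out as expected; that is what the "assuming it's contiguous" comment is about.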
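In the same spirit, a sketch of the transpose-free conv-state update (the "shift left, write last column, multiply and sum" sequence). Again only the public ggml API is assumed; `d_conv`/`d_inner` are tiny so the result is printable, and `x` is created directly as a `{1, d_inner}` column here, so the `ggml_cont(ggml_transpose(x))` step from the real code is not needed:

```c
// sketch: shifting a {d_conv, d_inner} conv state without transposing it
#include "ggml.h"
#include <stdio.h>

int main(void) {
    struct ggml_init_params params = { /*.mem_size =*/ 16*1024*1024, /*.mem_buffer =*/ NULL, /*.no_alloc =*/ false };
    struct ggml_context * ctx = ggml_init(params);

    const int d_conv = 4, d_inner = 2;

    struct ggml_tensor * conv_state = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, d_conv, d_inner);
    struct ggml_tensor * x          = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, d_inner); // already a column
    struct ggml_tensor * conv_w     = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, d_conv, d_inner);

    // recognizable values: conv_state rows are {0,1,2,3} and {4,5,6,7}, x is {100, 101}
    for (int i = 0; i < d_conv*d_inner; i++) { ((float *) conv_state->data)[i] = (float) i; }
    for (int i = 0; i < d_inner; i++)        { ((float *) x->data)[i] = 100.0f + i; }
    for (int i = 0; i < d_conv*d_inner; i++) { ((float *) conv_w->data)[i] = 1.0f; }

    // shift each row left by one element: view columns 1..d_conv-1, set them back at offset 0
    struct ggml_tensor * shifted = ggml_set_2d(ctx, conv_state,
        ggml_view_2d(ctx, conv_state, d_conv - 1, d_inner, conv_state->nb[1], ggml_element_size(conv_state)*1),
        conv_state->nb[1], 0);

    // write the new column into the last position of each row
    shifted = ggml_set_2d(ctx, shifted, x, conv_state->nb[1], ggml_element_size(conv_state)*(d_conv - 1));

    // multiply by the (here all-ones) conv weights and reduce along ne[0] => {1, d_inner}
    struct ggml_tensor * xs = ggml_sum_rows(ctx, ggml_mul(ctx, shifted, conv_w));

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, shifted);
    ggml_build_forward_expand(gf, xs);
    ggml_graph_compute_with_ctx(ctx, gf, 1);

    for (int r = 0; r < d_inner; r++) {
        for (int c = 0; c < d_conv; c++) {
            printf("%6.1f", ((float *) shifted->data)[r*d_conv + c]);
        }
        printf("\n");
    }
    // prints:   1.0   2.0   3.0 100.0
    //           5.0   6.0   7.0 101.0
    printf("sums: %.1f %.1f\n", ((float *) xs->data)[0], ((float *) xs->data)[1]); // 106.0 119.0

    ggml_free(ctx);
    return 0;
}
```

The key point is that with the state stored as `{d_conv, d_inner}`, both the shift and the last-column write are plain `ggml_set_2d` calls using the destination's own `nb[1]` as the row stride, so the reshape/transpose round-trip the old code needed is gone.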