Skip to content

Commit e67c420

Browse files
committed
mamba : refactor recurrent conv, resulting in 20% perf increase
It's still slower than I'd like, but I haven't really optimized `ggml_exp` yet. I also refactored `ggml_exp` to work with tensors of more than 2 dimensions.
1 parent a66e7f6 commit e67c420

File tree

2 files changed

+23
-22
lines changed

2 files changed

+23
-22
lines changed

ggml.c

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8421,16 +8421,19 @@ static void ggml_compute_forward_exp_f32(
84218421
return;
84228422
}
84238423

8424-
const int n = ggml_nrows(src0);
8425-
const int nc = src0->ne[0];
8426-
84278424
GGML_ASSERT( dst->nb[0] == sizeof(float));
84288425
GGML_ASSERT(src0->nb[0] == sizeof(float));
84298426

8430-
for (int i = 0; i < n; i++) {
8431-
ggml_vec_exp_f32(nc,
8432-
(float *) ((char *) dst->data + i*( dst->nb[1])),
8433-
(float *) ((char *) src0->data + i*(src0->nb[1])));
8427+
GGML_TENSOR_UNARY_OP_LOCALS
8428+
8429+
for (int64_t i3 = 0; i3 < ne03; i3++) {
8430+
for (int64_t i2 = 0; i2 < ne02; i2++) {
8431+
for (int64_t i1 = 0; i1 < ne01; i1++) {
8432+
float * src_row = (float *) ((char *) src0->data + i1*nb01 + i2*nb02 + i3*nb03);
8433+
float * dst_row = (float *) ((char *) dst->data + i1*nb1 + i2*nb2 + i3*nb3);
8434+
ggml_vec_exp_f32(ne00, dst_row, src_row);
8435+
}
8436+
}
84348437
}
84358438
}
84368439

llama.cpp

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6931,15 +6931,14 @@ struct llm_build_context {
69316931
struct ggml_tensor * cur;
69326932
struct ggml_tensor * inpL;
69336933

6934+
// NOTE: not sure what's the difference between the sequence length and the batch size in the paper.
69346935
// {n_embd, batch}
69356936
inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
69366937
cb(inpL, "inp_embd", -1);
69376938

69386939
for (int il = 0; il < n_layer; ++il) {
69396940
// (ab)using the kv cache to store the state
6940-
// NOTE: the conv_state is transposed to ease shifting it.
6941-
// if you figured out a way to shift it without transposing it like this, go ahead and fix this.
6942-
ggml_tensor * conv_state = kv_self.k_l[il]; // {d_inner, d_conv}
6941+
ggml_tensor * conv_state = ggml_reshape_2d(ctx0, kv_self.k_l[il], d_conv, d_inner);
69436942
ggml_tensor * ssm_state = ggml_reshape_2d(ctx0, kv_self.v_l[il], d_state, d_inner);
69446943

69456944
// norm
@@ -6948,33 +6947,32 @@ struct llm_build_context {
69486947
LLM_NORM_RMS, cb, il);
69496948
cb(cur, "attn_norm", il);
69506949

6951-
// {n_embd, batch} * {n_embd, 2*d_inner} = {batch, 2*d_inner}
6952-
struct ggml_tensor * xz = ggml_mul_mat(ctx0, cur, model.layers[il].ssm_in);
6950+
// {n_embd, 2*d_inner} * {n_embd, batch} = {2*d_inner, batch}
6951+
struct ggml_tensor * xz = ggml_mul_mat(ctx0, model.layers[il].ssm_in, cur);
69536952
// split the above in two
69546953
// assuming it's contiguous
6955-
// FIXME: handle batches of more than 1 token
6956-
struct ggml_tensor * x = ggml_view_1d(ctx0, xz, d_inner, 0);
6957-
struct ggml_tensor * z = ggml_view_1d(ctx0, xz, d_inner, ggml_element_size(xz)*d_inner);
6954+
// {d_inner, batch}
6955+
struct ggml_tensor * x = ggml_view_2d(ctx0, xz, d_inner, xz->ne[1], xz->nb[1], 0);
6956+
struct ggml_tensor * z = ggml_view_2d(ctx0, xz, d_inner, xz->ne[1], xz->nb[1], ggml_element_size(xz)*d_inner);
69586957

69596958
cur = x;
69606959

69616960
// conv
69626961
{
69636962
// shift conv state left
6964-
conv_state = ggml_set_1d(ctx0, conv_state, ggml_view_1d(ctx0, conv_state, (d_conv - 1)*d_inner, ggml_element_size(conv_state)*d_inner), 0);
6963+
conv_state = ggml_set_2d(ctx0, conv_state, ggml_view_2d(ctx0, conv_state, (d_conv - 1), d_inner, conv_state->nb[1], ggml_element_size(conv_state)*1), conv_state->nb[1], 0);
69656964

69666965
// update last column
6967-
conv_state = ggml_set_1d(ctx0, conv_state, x, ggml_element_size(conv_state)*(d_conv - 1)*d_inner);
6966+
// x here is {d_inner, 1} (a row), but should be {1, d_inner} (a column)
6967+
conv_state = ggml_set_2d(ctx0, conv_state, ggml_cont(ctx0, ggml_transpose(ctx0, x)), conv_state->nb[1], ggml_element_size(conv_state)*(d_conv - 1));
69686968

69696969
ggml_build_forward_expand(gf, ggml_cpy(ctx0, conv_state, ggml_view_tensor(ctx0, kv_self.k_l[il])));
69706970

69716971
// rearrange and sum
6972-
conv_state = ggml_reshape_2d(ctx0, conv_state, d_inner, d_conv);
6973-
// TODO: find a way to directly shift a 2d conv_state, avoiding the need to transpose here.
6974-
conv_state = ggml_cont(ctx0, ggml_transpose(ctx0, conv_state));
6975-
6976-
// --> {1, d_inner}
6972+
// no need to rearrange the conv_state, since it's already in the right shape
6973+
// => {1, d_inner}
69776974
x = ggml_sum_rows(ctx0, ggml_mul(ctx0, conv_state, model.layers[il].ssm_conv1d));
6975+
// => {d_inner, 1}
69786976
x = ggml_transpose(ctx0, x);
69796977

69806978
// bias

0 commit comments

Comments
 (0)