
Commit 711c77b

mamba : recurrent inference almost works, but incoherent
1 parent 70d84e8 commit 711c77b

File tree

- convert-hf-to-gguf.py
- ggml.c
- llama.cpp

3 files changed: 129 additions & 42 deletions

convert-hf-to-gguf.py

Lines changed: 46 additions & 2 deletions
@@ -261,7 +261,7 @@ def _get_model_architecture(self) -> gguf.MODEL_ARCH:
         if arch == "InternLM2ForCausalLM":
             return gguf.MODEL_ARCH.INTERNLM2
         if arch == "MambaForCausalLM":
-            return gguf.MODEL_ARCH
+            return gguf.MODEL_ARCH.MAMBA

         raise NotImplementedError(f'Architecture "{arch}" not supported!')

@@ -1503,13 +1503,57 @@ class MambaModel(Model):
     def set_gguf_parameters(self):
         d_model = self.hparams["d_model"]
         self.gguf_writer.add_name(self.dir_model.name)
+        self.gguf_writer.add_context_length(128) # arbitrary value; it shouldn't be important for Mamba
         self.gguf_writer.add_embedding_length(d_model)
-        self.gguf_writer.add_block_count(self.hparams["n_layer"])
+        self.gguf_writer.add_feed_forward_length(0) # unused, but seemingly required when loading
         self.gguf_writer.add_head_count(2 * d_model) # d_inner
+        self.gguf_writer.add_block_count(self.hparams["n_layer"])
+        self.gguf_writer.add_layer_norm_rms_eps(1e-5)
         self.gguf_writer.add_key_length(4) # d_conv
         self.gguf_writer.add_value_length(16) # d_state
         self.gguf_writer.add_file_type(self.ftype)

+    def write_tensors(self):
+        block_count = self.hparams["n_layer"]
+        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
+        for name, data_torch in self.get_tensors():
+            old_dtype = data_torch.dtype
+
+            # convert any unsupported data types to float32
+            if data_torch.dtype not in (torch.float16, torch.float32):
+                data_torch = data_torch.to(torch.float32)
+
+            # map tensor names
+            new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
+            if new_name is None:
+                print(f"Can not map tensor {name!r}")
+                sys.exit()
+
+            if name.endswith(".A_log"):
+                print("A_log --> A ==> " + new_name)
+                data_torch = -torch.exp(data_torch)
+
+            data = data_torch.squeeze().numpy()
+
+            n_dims = len(data.shape)
+            data_dtype = data.dtype
+
+            # if f32 desired, convert any float16 to float32
+            if self.ftype == 0 and data_dtype == np.float16:
+                data = data.astype(np.float32)
+
+            # TODO: Why can't we use these float16 as-is? There should be no reason to store float16 as float32
+            if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
+                data = data.astype(np.float32)
+
+            # if f16 desired, convert big float32 2-dim weight tensors to float16
+            if self.ftype == 1 and data_dtype == np.float32 and new_name.removesuffix(".weight").endswith((".ssm_in", ".ssm_out", "token_embd", "output")) and n_dims == 2:
+                data = data.astype(np.float16)
+
+            print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
+
+            self.gguf_writer.add_tensor(new_name, data)
+
 ###### CONVERSION LOGIC ######

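The A_log --> A conversion in write_tensors() above matches the reference Mamba parameterization: checkpoints store A_log, and the decay matrix is recovered as A = -exp(A_log), which is strictly negative, so exp(A*dt) stays between 0 and 1 for any positive step size dt. A minimal NumPy sketch of that invariant (illustrative only; the tensor values and the 1536x16 shape are made-up examples, not taken from this commit):

    import numpy as np

    A_log = np.random.randn(1536, 16).astype(np.float32)  # hypothetical {d_inner, d_state} parameter
    A = -np.exp(A_log)                                     # same transform as the converter applies
    dt = 0.01                                              # any positive step size
    decay = np.exp(A * dt)
    assert (A < 0).all()                                   # decay rates are strictly negative
    assert ((decay > 0) & (decay < 1)).all()               # so the recurrent state shrinks, never grows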
ggml.c

Lines changed: 2 additions & 2 deletions
@@ -5143,7 +5143,7 @@ struct ggml_tensor * ggml_soft_max_back_inplace(

 // ggml_soft_plus

-struct ggml_tensor * ggml_soft_plus_impl(
+static struct ggml_tensor * ggml_soft_plus_impl(
        struct ggml_context * ctx,
        struct ggml_tensor * a,
        bool inplace) {
@@ -15193,7 +15193,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
         case GGML_OP_SOFT_PLUS:
             {
                 ggml_compute_forward_soft_plus(params, tensor->src[0], tensor);
-            }
+            } break;
         case GGML_OP_ROPE:
             {
                 ggml_compute_forward_rope(params, tensor->src[0], tensor->src[1], tensor);
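For context on the ggml_soft_plus operator touched above: it is used further down to turn the projected dt into a positive step size, presumably as the standard softplus, softplus(x) = log(1 + exp(x)). A small NumPy sketch of that function (an assumption about the op's math, not the ggml kernel itself):

    import numpy as np

    def soft_plus(x: np.ndarray) -> np.ndarray:
        # log(1 + exp(x)), with large inputs passed through to avoid overflow in exp()
        return np.where(x > 20.0, x, np.log1p(np.exp(np.minimum(x, 20.0))))

    print(soft_plus(np.array([-5.0, 0.0, 5.0])))  # ~[0.0067, 0.693, 5.0067]: smooth, strictly positive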
llama.cpp

Lines changed: 81 additions & 38 deletions
@@ -1579,7 +1579,6 @@ struct llama_layer {
     struct ggml_tensor * ffn_up_b; // b3
     struct ggml_tensor * ffn_act;

-
     // mamba proj
     struct ggml_tensor * ssm_in;
     struct ggml_tensor * ssm_x;
@@ -3107,6 +3106,7 @@ static void llm_load_hparams(
             } break;
         case LLM_ARCH_MAMBA:
             {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
                 switch (hparams.n_layer) {
                     case 24:
                         switch (hparams.n_embd) {
@@ -3127,7 +3127,7 @@ static void llm_load_hparams(
                         } break;
                     default: model.type = e_model::MODEL_UNKNOWN;
                 }
-            }
+            } break;
         default: (void)0;
     }

@@ -3591,7 +3591,10 @@ static bool llm_load_tensors(
     const int64_t n_vocab = hparams.n_vocab;
     const int64_t n_ff = hparams.n_ff;

-    GGML_ASSERT(n_embd_gqa == n_embd_k_gqa);
+    // Mamba uses these in its own way
+    if (model.arch != LLM_ARCH_MAMBA) {
+        GGML_ASSERT(n_embd_gqa == n_embd_k_gqa);
+    }

     ggml_context * ctx_input = ctx_map.at(model.buft_input.buft);
     ggml_context * ctx_output = ctx_map.at(model.buft_output.buft);
@@ -4176,19 +4179,21 @@ static bool llm_load_tensors(
                 } break;
             case LLM_ARCH_MAMBA:
                 {
-                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
-
                     const int64_t d_conv = hparams.n_embd_head_k;
                     const int64_t d_state = hparams.n_embd_head_v;
                     const int64_t d_inner = hparams.n_head;
                     // FIXME: ceiling instead of floor
                     const int64_t dt_rank = n_embd / 16;
                     GGML_ASSERT(2 * n_embd == d_inner);
+                    // round up the vocab size to the next multiple of 8
+                    const int64_t rounded_vocab = (n_vocab + 7) & -8;
+
+                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, rounded_vocab});

                     // output
                     {
                         model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
-                        model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab});
+                        model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, rounded_vocab});
                     }

                     for (int i = 0; i < n_layer; ++i) {
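The (n_vocab + 7) & -8 expression above rounds the vocabulary size up to the next multiple of 8, so the tensor shapes match checkpoints whose embedding and output matrices are padded. A quick sanity check (50277 is only an example vocabulary size, not something read from this diff):

    def round_up_to_multiple_of_8(n: int) -> int:
        # adding 7 and clearing the low three bits rounds up to a multiple of 8
        return (n + 7) & -8

    assert round_up_to_multiple_of_8(50277) == 50280
    assert round_up_to_multiple_of_8(50280) == 50280  # already-aligned sizes are unchanged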
@@ -4205,17 +4210,17 @@ static bool llm_load_tensors(

                         layer.ssm_in = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, 2*d_inner});

-                        layer.ssm_conv1d = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, 1, d_inner});
+                        layer.ssm_conv1d = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner});
                         layer.ssm_conv1d_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner});

                         layer.ssm_x = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_X, "weight", i), {d_inner, dt_rank + 2*d_state});

                         layer.ssm_dt = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_DT, "weight", i), {dt_rank, d_inner});
                         layer.ssm_dt_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_DT, "bias", i), {d_inner});

-                        // FIXME: maybe no suffix for these
-                        layer.ssm_a = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_A, "weight", i), {d_state, d_inner});
-                        layer.ssm_d = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_D, "weight", i), {d_inner});
+                        // no "weight" suffix for these
+                        layer.ssm_a = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_A, i), {d_state, d_inner});
+                        layer.ssm_d = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_D, i), {d_inner});

                         // out_proj
                         layer.ssm_out = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd});
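To get a concrete feel for the per-layer shapes created above, here is a sketch with assumed example dimensions (d_model = 768 is hypothetical; only the relations between the sizes come from this diff, and the shapes are written in ggml's ne order as in the code above):

    n_embd  = 768            # d_model (assumed example)
    d_inner = 2 * n_embd     # 1536
    d_conv  = 4
    d_state = 16
    dt_rank = n_embd // 16   # 48 (floor division; see the FIXME about ceiling vs floor above)

    per_layer_ssm_shapes = {
        "ssm_in":       (n_embd, 2 * d_inner),
        "ssm_conv1d":   (d_conv, d_inner),
        "ssm_conv1d_b": (d_inner,),
        "ssm_x":        (d_inner, dt_rank + 2 * d_state),
        "ssm_dt":       (dt_rank, d_inner),
        "ssm_dt_b":     (d_inner,),
        "ssm_a":        (d_state, d_inner),
        "ssm_d":        (d_inner,),
        "ssm_out":      (d_inner, n_embd),
    }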
@@ -6909,16 +6914,18 @@ struct llm_build_context {
         return gf;
     }

-    struct ggml_cgraph * build_mamba(bool use_conv) {
+    struct ggml_cgraph * build_mamba() {
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);

+        const bool use_conv = batch.n_tokens > 1;
         GGML_ASSERT(use_conv == false); // TODO: implement

-        const int64_t d_model = hparams.n_embd;
-        const int64_t d_inner = hparams.n_head;
+        // hopefully the compiler does constant folding
+        const int64_t d_model = n_embd;
+        const int64_t d_inner = n_head;
         GGML_ASSERT(2 * d_model == d_inner);
-        const int64_t d_conv = hparams.n_embd_head_k;
-        const int64_t d_state = hparams.n_embd_head_v;
+        const int64_t d_conv = n_embd_head_k;
+        const int64_t d_state = n_embd_head_v;
         const int64_t dt_rank = d_model / 16;

         struct ggml_tensor * cur;
@@ -6930,8 +6937,10 @@ struct llm_build_context {

         for (int il = 0; il < n_layer; ++il) {
             // (ab)using the kv cache to store the state
-            ggml_tensor * conv_state = kv_self.k_l[il]; // {d_conv, d_inner}
-            ggml_tensor * ssm_state  = kv_self.v_l[il]; // {d_state, d_inner}
+            // NOTE: the conv_state is transposed to ease shifting it.
+            // If you figure out a way to shift it without transposing it like this, go ahead and fix this.
+            ggml_tensor * conv_state = kv_self.k_l[il]; // {d_inner, d_conv}
+            ggml_tensor * ssm_state  = ggml_reshape_2d(ctx0, kv_self.v_l[il], d_state, d_inner);

             // norm
             cur = llm_build_norm(ctx0, inpL, hparams,
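The conv_state is kept transposed (see the NOTE above) so that the shift in the next hunk can be expressed as two flat ggml_set_1d copies over contiguous memory. Ignoring that layout detail, the update being built is, in row-major NumPy terms, roughly the following (a sketch with assumed example sizes and random stand-in values):

    import numpy as np

    d_conv, d_inner = 4, 1536                                         # assumed example sizes
    conv_state = np.zeros((d_conv, d_inner), dtype=np.float32)        # rows = time steps, oldest first
    conv_w     = np.random.randn(d_conv, d_inner).astype(np.float32)  # stand-in for ssm_conv1d
    x_new      = np.random.randn(d_inner).astype(np.float32)          # the newly projected token

    conv_state[:-1] = conv_state[1:]        # "shift conv state left": drop the oldest step
    conv_state[-1]  = x_new                 # "update last column": append the new token
    x = (conv_state * conv_w).sum(axis=0)   # the "rearrange and sum" per-channel dot product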
@@ -6943,36 +6952,73 @@ struct llm_build_context {
             // {n_embd, batch} * {n_embd, 2*d_inner} = {batch, 2*d_inner}
             struct ggml_tensor * xz = ggml_mul_mat(ctx0, cur, model.layers[il].ssm_in);
             // split the above in two
+            // assuming it's contiguous
+            // FIXME: handle batches of more than 1 token
             struct ggml_tensor * x = ggml_view_1d(ctx0, xz, d_inner, 0);
-            struct ggml_tensor * z = ggml_view_1d(ctx0, xz, d_inner, d_inner);
+            struct ggml_tensor * z = ggml_view_1d(ctx0, xz, d_inner, ggml_element_size(xz)*d_inner);
+
+            cur = x;

-            // FIXME: figure out when to transpose
             // conv
             {
-                // TODO: figure out how to do a row-wise dot product
-                // TODO: use the kv-cache to store the state
-                kv_self.k_l[il];
+                // shift conv state left
+                conv_state = ggml_set_1d(ctx0, conv_state, ggml_view_1d(ctx0, conv_state, (d_conv - 1)*d_inner, ggml_element_size(conv_state)*d_inner), 0);
+
+                // update last column
+                conv_state = ggml_set_1d(ctx0, conv_state, x, ggml_element_size(conv_state)*(d_conv - 1)*d_inner);
+
+                ggml_build_forward_expand(gf, ggml_cpy(ctx0, conv_state, ggml_view_tensor(ctx0, kv_self.k_l[il])));

-                // FIXME: this is wrong
-                cur = ggml_conv_1d(ctx0, cur, model.layers[il].ssm_conv1d, 1, d_conv - 1, 1);
+                // rearrange and sum
+                conv_state = ggml_reshape_2d(ctx0, conv_state, d_inner, d_conv);
+                // TODO: find a way to directly shift a 2d conv_state, avoiding the need to transpose here.
+                conv_state = ggml_cont(ctx0, ggml_transpose(ctx0, conv_state));

-                cur = ggml_add(ctx0, cur, model.layers[il].ssm_conv1d_b);
+                // --> {1, d_inner}
+                x = ggml_sum_rows(ctx0, ggml_mul(ctx0, conv_state, model.layers[il].ssm_conv1d));
+                x = ggml_transpose(ctx0, x);

-                // TODO: there's some SiLU in there (but no ffn? or is the conv an ffn?)
-                cur = ggml_silu(ctx0, cur);
+                // bias
+                x = ggml_add(ctx0, x, model.layers[il].ssm_conv1d_b);
+
+                x = ggml_silu(ctx0, x);
             }

             // ssm
             {
+                // {2*n_embd, batch} * {2*n_embd, dt_rank + 2*d_state} = {batch, dt_rank + 2*d_state}
+                struct ggml_tensor * x_db = ggml_mul_mat(ctx0, x, model.layers[il].ssm_x);
+                // FIXME: handle batches of more than 1 token
+                struct ggml_tensor * dt = ggml_view_1d(ctx0, x_db, dt_rank, 0);
+                struct ggml_tensor * B = ggml_view_1d(ctx0, x_db, d_state, ggml_element_size(x_db)*dt_rank);
+                struct ggml_tensor * C = ggml_view_1d(ctx0, x_db, d_state, ggml_element_size(x_db)*(dt_rank+d_state));

-                // TODO: use ggml_soft_plus here
+                // {dt_rank} * {dt_rank, d_inner} = {1, d_inner}
+                dt = ggml_mul_mat(ctx0, dt, model.layers[il].ssm_dt);
+                dt = ggml_add(ctx0, dt, ggml_transpose(ctx0, model.layers[il].ssm_dt_b));
+                dt = ggml_soft_plus(ctx0, dt);

-            }
+                // => {d_state, d_inner}
+                struct ggml_tensor * dA = ggml_exp(ctx0, ggml_mul(ctx0, model.layers[il].ssm_a, dt));

-            // TODO: there's some SiLU again towards the end. Can the `llm_build_ffn` helper be used?
-            // Maybe the best way is to implement it, _then_ check if that helper would do the same thing.
-            // discretize
-            {
+                // => {d_state, d_inner}
+                struct ggml_tensor * dB = ggml_out_prod(ctx0, B, ggml_transpose(ctx0, dt));
+
+                // => {d_state, d_inner}
+                cur = ggml_mul(ctx0, dB, ggml_transpose(ctx0, x));
+
+                ssm_state = ggml_add(ctx0, ggml_mul(ctx0, ssm_state, dA), cur);
+
+                ggml_build_forward_expand(gf, ggml_cpy(ctx0, ssm_state, ggml_view_tensor(ctx0, kv_self.v_l[il])));
+
+                // row-wise dot product ("dn,n->d")
+                // {d_state, d_inner} * {d_state} => {d_inner, 1}
+                struct ggml_tensor * y = ggml_mul_mat(ctx0, ssm_state, C);
+                y = ggml_add(ctx0, y, ggml_mul(ctx0, model.layers[il].ssm_d, x));
+                y = ggml_mul(ctx0, y, ggml_silu(ctx0, z));
+
+                // {d_inner, n_embd} * {d_inner, 1} = {n_embd, 1}
+                cur = ggml_mul_mat(ctx0, model.layers[il].ssm_out, y);
             }

             // residual
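Putting the whole per-layer, single-token update above into plain NumPy may help in following the graph; this is a hedged sketch of the math it encodes (random stand-in weights, row-major shapes rather than ggml's ne order, and the conv/gating inputs x and z assumed to be already computed), not the llama.cpp implementation itself:

    import numpy as np

    def silu(v):      return v / (1.0 + np.exp(-v))
    def softplus(v):  return np.log1p(np.exp(v))

    d_inner, d_state, dt_rank = 1536, 16, 48                  # assumed example sizes
    rng = np.random.default_rng(0)

    A    = -np.exp(rng.standard_normal((d_inner, d_state)))   # from A_log, strictly negative
    D    = rng.standard_normal(d_inner)                       # ssm_d
    W_dt = rng.standard_normal((d_inner, dt_rank)) * 0.1      # ssm_dt
    b_dt = rng.standard_normal(d_inner)                       # ssm_dt bias
    W_x  = rng.standard_normal((dt_rank + 2 * d_state, d_inner)) * 0.1  # ssm_x

    def ssm_step(state, x, z):
        """One recurrent step; state has shape (d_inner, d_state), x and z have shape (d_inner,)."""
        x_db = W_x @ x                                        # project to dt, B, C
        dt_in, B, C = x_db[:dt_rank], x_db[dt_rank:dt_rank + d_state], x_db[dt_rank + d_state:]
        dt = softplus(W_dt @ dt_in + b_dt)                    # (d_inner,)  ggml_soft_plus
        dA = np.exp(dt[:, None] * A)                          # (d_inner, d_state)
        dB = dt[:, None] * B[None, :]                         # outer product, ggml_out_prod
        state = state * dA + dB * x[:, None]                  # decay old state, add new input
        y = state @ C + D * x                                 # row-wise dot ("dn,n->d") + skip
        y = y * silu(z)                                       # gate with the z branch
        return state, y                                       # y then goes through ssm_out

    state = np.zeros((d_inner, d_state))
    x = rng.standard_normal(d_inner)                          # conv branch output
    z = rng.standard_normal(d_inner)                          # gate half of the in_proj split
    state, y = ssm_step(state, x, z)
    print(y.shape)                                            # (1536,)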
@@ -6983,11 +7029,8 @@ struct llm_build_context {
             inpL = cur;
         }

-        // the last step of each layer already makes these equivalent
-        // cur = inpL;
-
         // final rmsnorm
-        cur = llm_build_norm(ctx0, cur, hparams,
+        cur = llm_build_norm(ctx0, inpL, hparams,
                 model.output_norm, NULL,
                 LLM_NORM_RMS, cb, -1);
         cb(cur, "result_norm", -1);
@@ -7165,7 +7208,7 @@ static struct ggml_cgraph * llama_build_graph(
             } break;
         case LLM_ARCH_MAMBA:
             {
-                result = llm.build_mamba(/* use_conv =*/ batch.n_tokens > 1);
+                result = llm.build_mamba();
             } break;
         default:
             GGML_ASSERT(false);
