Skip to content

Commit a66e7f6

Browse files
committed
convert : optionally use d_conv and d_state from config.json for Mamba
1 parent 5b0373e commit a66e7f6

File tree

1 file changed

+9 additions, -5 deletions (lines changed)

1 file changed

+9 additions, -5 deletions (lines changed)

convert-hf-to-gguf.py

Lines changed: 9 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1502,15 +1502,19 @@ def write_tensors(self):
15021502
class MambaModel(Model):
15031503
def set_gguf_parameters(self):
15041504
d_model = self.hparams["d_model"]
1505+
d_inner = self.hparams.get("d_inner", 2 * d_model)
1506+
# Fail early for models which don't have a block expansion factor of 2
1507+
assert d_inner == 2 * d_model
1508+
15051509
self.gguf_writer.add_name(self.dir_model.name)
1506-
self.gguf_writer.add_context_length(128) # arbitrary value; it shouldn't be important for Mamba
1510+
self.gguf_writer.add_context_length(2**20) # arbitrary value; for those who use the default
15071511
self.gguf_writer.add_embedding_length(d_model)
15081512
self.gguf_writer.add_feed_forward_length(0) # unused, but seemingly required when loading
1509-
self.gguf_writer.add_head_count(2 * d_model) # d_inner
1513+
self.gguf_writer.add_head_count(d_inner)
15101514
self.gguf_writer.add_block_count(self.hparams["n_layer"])
1511-
self.gguf_writer.add_layer_norm_rms_eps(1e-5)
1512-
self.gguf_writer.add_key_length(4) # d_conv
1513-
self.gguf_writer.add_value_length(16) # d_state
1515+
self.gguf_writer.add_layer_norm_rms_eps(self.hparams.get("rms_norm_eps", 1e-5))
1516+
self.gguf_writer.add_key_length(self.hparams.get("d_conv", 4))
1517+
self.gguf_writer.add_value_length(self.hparams.get("d_state", 16))
15141518
self.gguf_writer.add_file_type(self.ftype)
15151519

15161520
def write_tensors(self):

0 commit comments

Comments
 (0)