
Commit e83d270

convert : adapt MiniCPM3 to separate rope_freqs insertion
MiniCPM3's tokenizer is treated as a SentencePiece tokenizer to avoid having to run its custom Python code, which mixes tokenization and tool-call handling in the same file.

gguf-py : add long and short RoPE factors to tensor mappings

Empty, but the key names are used to populate the mappings.
1 parent ed0f2c4 commit e83d270
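
As a rough illustration of the tokenizer point above (not part of this commit): a plain SentencePiece model file can be loaded directly with the sentencepiece library, which is what the SentencePiece treatment relies on instead of the model's custom tokenization code. The file path below is a placeholder, not a path from the repository.

# Hedged sketch: loading a SentencePiece tokenizer model directly.
# "path/to/tokenizer.model" is hypothetical.
import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.Load("path/to/tokenizer.model")          # plain SentencePiece model file
print(sp.EncodeAsPieces("hello world"))     # tokenize without any custom Python code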

3 files changed: 17 additions & 13 deletions


convert_hf_to_gguf.py

Lines changed: 12 additions & 13 deletions
@@ -1862,8 +1862,6 @@ class MiniCPM3Model(Model):
     def set_gguf_parameters(self):
         hparams = self.hparams
 
-        rope_dims = hparams["qk_rope_head_dim"]
-
         self.gguf_writer.add_file_type(self.ftype)
         self.gguf_writer.add_context_length(hparams["max_position_embeddings"])
         self.gguf_writer.add_embedding_length(hparams["hidden_size"])
@@ -1879,24 +1877,25 @@ def set_gguf_parameters(self):
         self.gguf_writer.add_key_length(hparams["qk_nope_head_dim"] + hparams["qk_rope_head_dim"])
         self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"])
 
+    def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
         rope_scaling = self.find_hparam(['rope_scaling'], True)
-        if rope_scaling is None:
-            return
+        if rope_scaling is not None:
+            rope_dims = self.hparams["qk_rope_head_dim"]
 
-        long_factors = rope_scaling.get('long_factor', None)
-        short_factors = rope_scaling.get('short_factor', None)
+            long_factors = rope_scaling.get('long_factor', None)
+            short_factors = rope_scaling.get('short_factor', None)
 
-        if long_factors is None or short_factors is None:
-            raise KeyError('Missing the required key rope_scaling.long_factor or rope_scaling_short_factor')
+            if long_factors is None or short_factors is None:
+                raise KeyError('Missing the required key rope_scaling.long_factor or rope_scaling_short_factor')
 
-        if len(long_factors) != len(short_factors) or len(long_factors) != rope_dims / 2:
-            raise ValueError(f'The length of rope long and short factors must be {rope_dims / 2}')
+            if len(long_factors) != len(short_factors) or len(long_factors) != rope_dims / 2:
+                raise ValueError(f'The length of rope long and short factors must be {rope_dims / 2}')
 
-        self.gguf_writer.add_tensor(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ROPE_FACTORS_LONG] + ".weight", np.array(long_factors, dtype=np.float32))
-        self.gguf_writer.add_tensor(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ROPE_FACTORS_SHORT] + ".weight", np.array(short_factors, dtype=np.float32))
+            yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FACTORS_LONG), torch.tensor(long_factors, dtype=torch.float32))
+            yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FACTORS_SHORT), torch.tensor(short_factors, dtype=torch.float32))
 
     def set_vocab(self):
-        self._set_vocab_llama_hf()
+        self._set_vocab_sentencepiece()
 
     def _reverse_hf_permute(self, weights: Tensor, n_head: int, n_kv_head: int | None = None) -> Tensor:
         if n_kv_head is not None and n_head != n_kv_head:
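
For context on the rope_dims / 2 length check above: in longrope-style scaling there is one correction factor per rotary frequency pair, and each factor divides the corresponding inverse frequency. The following is a hedged, standalone sketch of that relationship (not code from this commit or from llama.cpp), with illustrative numbers rather than values from the MiniCPM3 config.

# Hedged sketch: why each factor list must contain exactly rope_dims / 2 entries --
# one scaling factor per rotary frequency pair.
import numpy as np

def scaled_inv_freq(rope_dims: int, factors: list[float], base: float = 10000.0) -> np.ndarray:
    # standard RoPE inverse frequencies: one per pair of rotary dimensions
    inv_freq = 1.0 / (base ** (np.arange(0, rope_dims, 2) / rope_dims))
    if len(factors) != rope_dims // 2:
        raise ValueError(f"expected {rope_dims // 2} factors, got {len(factors)}")
    # longrope-style scaling: each frequency is divided by its correction factor
    return inv_freq / np.asarray(factors, dtype=np.float64)

# illustrative dimensions only
print(scaled_inv_freq(32, [1.0] * 16).shape)  # -> (16,)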

gguf-py/gguf/constants.py

Lines changed: 2 additions & 0 deletions
@@ -877,6 +877,8 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.OUTPUT_NORM,
         MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ROPE_FACTORS_LONG,
+        MODEL_TENSOR.ROPE_FACTORS_SHORT,
         MODEL_TENSOR.ATTN_NORM,
         MODEL_TENSOR.ATTN_Q_A,
         MODEL_TENSOR.ATTN_Q_B,

gguf-py/gguf/tensor_mapping.py

Lines changed: 3 additions & 0 deletions
@@ -87,6 +87,9 @@ class TensorNameMap:
             "rope.freqs",                  # llama-pth
             "rotary_pos_emb.inv_freq",     # chatglm
         ),
+
+        MODEL_TENSOR.ROPE_FACTORS_LONG:  (),
+        MODEL_TENSOR.ROPE_FACTORS_SHORT: (),
     }
 
     block_mappings_cfg: dict[MODEL_TENSOR, tuple[str, ...]] = {
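
The commit message's note that the new entries are "Empty, but the key names are used to populate the mappings" refers to how the name map is built: even an empty tuple of source names registers the canonical GGUF tensor name for that key, so the tensors yielded by generate_extra_tensors() still resolve. The following is a simplified, hypothetical sketch of that mechanism; the names and structure are illustrative, not the actual gguf-py classes.

# Hypothetical sketch of the mapping mechanism alluded to in the commit message.
from enum import IntEnum

class ModelTensor(IntEnum):
    ROPE_FREQS = 0
    ROPE_FACTORS_LONG = 1
    ROPE_FACTORS_SHORT = 2

TENSOR_NAMES: dict[ModelTensor, str] = {
    ModelTensor.ROPE_FREQS: "rope_freqs",
    ModelTensor.ROPE_FACTORS_LONG: "rope_factors_long",
    ModelTensor.ROPE_FACTORS_SHORT: "rope_factors_short",
}

# per-tensor tuples of source (e.g. HF) names, mirroring the mappings_cfg above
MAPPINGS_CFG: dict[ModelTensor, tuple[str, ...]] = {
    ModelTensor.ROPE_FREQS: ("rope.freqs", "rotary_pos_emb.inv_freq"),
    ModelTensor.ROPE_FACTORS_LONG: (),   # empty: nothing to match in the source model
    ModelTensor.ROPE_FACTORS_SHORT: (),
}

def build_mapping() -> dict[str, tuple[ModelTensor, str]]:
    mapping: dict[str, tuple[ModelTensor, str]] = {}
    for tensor, keys in MAPPINGS_CFG.items():
        gguf_name = TENSOR_NAMES[tensor]
        # the canonical GGUF name is always registered, so a tensor emitted
        # under that name (e.g. by generate_extra_tensors) still resolves
        mapping[gguf_name] = (tensor, gguf_name)
        for key in keys:
            mapping[key] = (tensor, gguf_name)
    return mapping

print(build_mapping()["rope_factors_long"])
# (<ModelTensor.ROPE_FACTORS_LONG: 1>, 'rope_factors_long')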
