Skip to content

Commit 15ebe59

Browse files
authored
convert : update phi-2 to latest HF repo (#4903)
* convert : update phi-2 to latest HF repo
  ggml-ci
* py : try to fix flake stuff
1 parent de473f5 commit 15ebe59

File tree

4 files changed

+65
-21
lines changed

4 files changed

+65
-21
lines changed

convert-hf-to-gguf.py

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,15 @@
2323
import gguf
2424

2525

26+
# check for any of the given keys in the dictionary and return the value of the first key found
27+
def get_key_opts(d, keys):
    """Return the value of the first key from `keys` that is present in `d`.

    Args:
        d: Mapping to search (used here with a model's hyper-parameter dict).
        keys: Candidate key names, tried in order; the first match wins.

    Returns:
        The value stored in `d` under the first matching key.

    Raises:
        SystemExit: With exit status 1 when none of `keys` is present in `d`.
    """
    for k in keys:
        if k in d:
            return d[k]
    # No candidate matched. A bare sys.exit() would terminate with status 0 and
    # make the failed run look successful, so exit non-zero and send the
    # diagnostic to stderr.
    print(f"Could not find any of {keys}", file=sys.stderr)
    sys.exit(1)
33+
34+
2635
###### MODEL DEFINITIONS ######
2736

2837
class SentencePieceTokenTypes(IntEnum):
@@ -257,10 +266,11 @@ def _set_vocab_gpt2(self):
257266
toktypes.append(gguf.TokenType.USER_DEFINED)
258267
elif reverse_vocab[i] in added_vocab:
259268
tokens.append(reverse_vocab[i])
260-
if tokenizer.added_tokens_decoder[i].special:
261-
toktypes.append(gguf.TokenType.CONTROL)
262-
else:
263-
toktypes.append(gguf.TokenType.USER_DEFINED)
269+
if hasattr(tokenizer, "added_tokens_decoder"):
270+
if tokenizer.added_tokens_decoder[i].special:
271+
toktypes.append(gguf.TokenType.CONTROL)
272+
else:
273+
toktypes.append(gguf.TokenType.USER_DEFINED)
264274
else:
265275
tokens.append(reverse_vocab[i])
266276
toktypes.append(gguf.TokenType.NORMAL)
@@ -1068,17 +1078,22 @@ def write_tensors(self):
10681078

10691079
class Phi2Model(Model):
10701080
def set_gguf_parameters(self):
1071-
block_count = self.hparams["n_layer"]
1081+
block_count = get_key_opts(self.hparams, ["num_hidden_layers", "n_layer"])
1082+
1083+
rot_pct = get_key_opts(self.hparams, ["partial_rotary_factor"])
1084+
n_embd = get_key_opts(self.hparams, ["hidden_size", "n_embd"])
1085+
n_head = get_key_opts(self.hparams, ["num_attention_heads", "n_head"])
10721086

10731087
self.gguf_writer.add_name("Phi2")
1074-
self.gguf_writer.add_context_length(self.hparams["n_positions"])
1075-
self.gguf_writer.add_embedding_length(self.hparams["n_embd"])
1076-
self.gguf_writer.add_feed_forward_length(4 * self.hparams["n_embd"])
1088+
self.gguf_writer.add_context_length(get_key_opts(self.hparams, ["n_positions", "max_position_embeddings"]))
1089+
1090+
self.gguf_writer.add_embedding_length(n_embd)
1091+
self.gguf_writer.add_feed_forward_length(4 * n_embd)
10771092
self.gguf_writer.add_block_count(block_count)
1078-
self.gguf_writer.add_head_count(self.hparams["n_head"])
1079-
self.gguf_writer.add_head_count_kv(self.hparams["n_head"])
1080-
self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
1081-
self.gguf_writer.add_rope_dimension_count(self.hparams["rotary_dim"])
1093+
self.gguf_writer.add_head_count(n_head)
1094+
self.gguf_writer.add_head_count_kv(n_head)
1095+
self.gguf_writer.add_layer_norm_eps(get_key_opts(self.hparams, ["layer_norm_epsilon", "layer_norm_eps"]))
1096+
self.gguf_writer.add_rope_dimension_count(int(rot_pct * n_embd) // n_head)
10821097
self.gguf_writer.add_file_type(self.ftype)
10831098
self.gguf_writer.add_add_bos_token(False)
10841099

gguf-py/gguf/constants.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -389,6 +389,9 @@ class MODEL_TENSOR(IntEnum):
389389
MODEL_TENSOR.OUTPUT,
390390
MODEL_TENSOR.ATTN_NORM,
391391
MODEL_TENSOR.ATTN_QKV,
392+
MODEL_TENSOR.ATTN_Q,
393+
MODEL_TENSOR.ATTN_K,
394+
MODEL_TENSOR.ATTN_V,
392395
MODEL_TENSOR.ATTN_OUT,
393396
MODEL_TENSOR.FFN_NORM,
394397
MODEL_TENSOR.FFN_DOWN,

gguf-py/gguf/tensor_mapping.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,7 @@ class TensorNameMap:
191191
"transformer.h.{bid}.mlp.w1", # qwen
192192
"h.{bid}.mlp.c_fc", # gpt2
193193
"transformer.h.{bid}.mlp.fc1", # phi2
194+
"model.layers.{bid}.mlp.fc1", # phi2
194195
"model.layers.layers.{bid}.mlp.up_proj", # plamo
195196
),
196197

@@ -232,6 +233,7 @@ class TensorNameMap:
232233
"model.layers.{bid}.mlp.dense_4h_to_h", # persimmon
233234
"h.{bid}.mlp.c_proj", # gpt2
234235
"transformer.h.{bid}.mlp.fc2", # phi2
236+
"model.layers.{bid}.mlp.fc2", # phi2
235237
"model.layers.layers.{bid}.mlp.down_proj", # plamo
236238
),
237239

llama.cpp

Lines changed: 33 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -574,6 +574,9 @@ static std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES =
574574
{ LLM_TENSOR_OUTPUT, "output" },
575575
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
576576
{ LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
577+
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
578+
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
579+
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
577580
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
578581
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
579582
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
@@ -3676,8 +3679,19 @@ static bool llm_load_tensors(
36763679
layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
36773680
layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd});
36783681

3679-
layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa});
3680-
layer.bqkv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa});
3682+
layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, false);
3683+
layer.bqkv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, false);
3684+
3685+
if (layer.wqkv == nullptr) {
3686+
layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd});
3687+
layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd});
3688+
3689+
layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa});
3690+
layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa});
3691+
3692+
layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa});
3693+
layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa});
3694+
}
36813695

36823696
layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
36833697
layer.bo = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd});
@@ -5637,15 +5651,25 @@ struct llm_build_context {
56375651

56385652
// self-attention
56395653
{
5640-
cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, attn_norm_output);
5641-
cb(cur, "wqkv", il);
5654+
struct ggml_tensor * Qcur = nullptr;
5655+
struct ggml_tensor * Kcur = nullptr;
5656+
struct ggml_tensor * Vcur = nullptr;
56425657

5643-
cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
5644-
cb(cur, "bqkv", il);
5658+
if (model.layers[il].wqkv) {
5659+
cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, attn_norm_output);
5660+
cb(cur, "wqkv", il);
56455661

5646-
struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
5647-
struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
5648-
struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
5662+
cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
5663+
cb(cur, "bqkv", il);
5664+
5665+
Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
5666+
Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
5667+
Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
5668+
} else {
5669+
Qcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wq, attn_norm_output), model.layers[il].bq);
5670+
Kcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wk, attn_norm_output), model.layers[il].bk);
5671+
Vcur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wv, attn_norm_output), model.layers[il].bv);
5672+
}
56495673

56505674
cb(Qcur, "Qcur", il);
56515675
cb(Kcur, "Kcur", il);

0 commit comments

Comments
 (0)