
Commit 5fc6dbd

model : adapt gemma3
ggml-ci
1 parent 226ff01 commit 5fc6dbd

File tree

1 file changed (+14, −14 lines changed)

src/llama-model.cpp

Lines changed: 14 additions & 14 deletions
@@ -7354,11 +7354,11 @@ struct llm_build_gemma2 : public llm_graph_context {
 };
 
 struct llm_build_gemma3 : public llm_graph_context {
-    llm_build_gemma3(const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+    llm_build_gemma3(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
         const int64_t n_embd_head_k = hparams.n_embd_head_k;
 
-        struct ggml_tensor * cur;
-        struct ggml_tensor * inpL;
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
 
         inpL = build_inp_embd(model.tok_embd);
 
@@ -7369,10 +7369,10 @@ struct llm_build_gemma3 : public llm_graph_context {
         }
 
         // inp_pos - contains the positions
-        struct ggml_tensor * inp_pos = build_inp_pos();
+        ggml_tensor * inp_pos = build_inp_pos();
 
         // TODO: is causal == true correct? might need some changes
-        auto inp_attn = build_attn_inp_kv_self(true, true);
+        auto * inp_attn = build_attn_inp_kv_unified(true, true);
 
         // "5-to-1 interleaved attention"
         // 5 layers of local attention followed by 1 layer of global attention
@@ -7381,8 +7381,8 @@ struct llm_build_gemma3 : public llm_graph_context {
         for (int il = 0; il < n_layer; ++il) {
             const bool is_sliding = il % sliding_window_pattern < (sliding_window_pattern - 1);
 
-            const float freq_base_l  = is_sliding ? 10000.0f : freq_base;
-            const float freq_scale_l = is_sliding ? 1.0f     : freq_scale;
+            const float freq_base_l  = is_sliding ? 10000.0f : freq_base;
+            const float freq_scale_l = is_sliding ? 1.0f     : freq_scale;
 
             // norm
             cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
@@ -7391,13 +7391,13 @@ struct llm_build_gemma3 : public llm_graph_context {
             // self-attention
             {
                 // compute Q and K and RoPE them
-                struct ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
                 cb(Qcur, "Qcur", il);
 
-                struct ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
                 cb(Kcur, "Kcur", il);
 
-                struct ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
                 cb(Vcur, "Vcur", il);
 
                 Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head, n_tokens);
@@ -7420,7 +7420,7 @@ struct llm_build_gemma3 : public llm_graph_context {
                     ext_factor, attn_factor, beta_fast, beta_slow);
                 cb(Kcur, "Kcur", il);
 
-                cur = build_attn(inp_attn.get(), gf,
+                cur = build_attn(inp_attn, gf,
                         model.layers[il].wo, NULL,
                         Qcur, Kcur, Vcur, nullptr, hparams.f_attention_scale, il);
             }
@@ -7432,12 +7432,12 @@ struct llm_build_gemma3 : public llm_graph_context {
 
             if (il == n_layer - 1) {
                 // skip computing output for unused tokens
-                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
                 cur  = ggml_get_rows(ctx0,  cur, inp_out_ids);
                 inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
             }
 
-            struct ggml_tensor * sa_out = ggml_add(ctx0, cur, inpL);
+            ggml_tensor * sa_out = ggml_add(ctx0, cur, inpL);
             cb(sa_out, "sa_out", il);
 
             cur = build_norm(sa_out,
@@ -11017,7 +11017,7 @@ llm_graph_result_ptr llama_model::build_graph(
             } break;
         case LLM_ARCH_GEMMA3:
             {
-                llm = std::make_unique<llm_build_gemma3>(params, gf);
+                llm = std::make_unique<llm_build_gemma3>(*this, params, gf);
            } break;
        case LLM_ARCH_STARCODER2:
            {
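
As background for the "5-to-1 interleaved attention" comment in the diff, the standalone sketch below shows how the is_sliding predicate and the per-layer RoPE frequency selection behave. The layer count and the global freq_base value are made up for illustration; only the predicate and the two ternaries mirror the committed code.

// Standalone sketch (not llama.cpp code): blocks of 5 local (sliding-window)
// layers followed by 1 global layer, and the per-layer RoPE base they get.
#include <cstdio>

int main() {
    const int   n_layer                = 12;         // hypothetical layer count
    const int   sliding_window_pattern = 6;          // 5 local layers + 1 global layer
    const float freq_base              = 1000000.0f; // placeholder global RoPE base
    const float freq_scale             = 1.0f;

    for (int il = 0; il < n_layer; ++il) {
        // same predicate as in the diff: only the last layer of each block of 6 is global
        const bool is_sliding = il % sliding_window_pattern < (sliding_window_pattern - 1);

        const float freq_base_l  = is_sliding ? 10000.0f : freq_base;
        const float freq_scale_l = is_sliding ? 1.0f     : freq_scale;

        printf("layer %2d: %-15s freq_base=%9.0f freq_scale=%.1f\n",
               il, is_sliding ? "local (sliding)" : "global", freq_base_l, freq_scale_l);
    }
    return 0;
}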
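
The diff also makes two interface changes: the builder now takes the llama_model by const reference, and the switch from build_attn_inp_kv_self to build_attn_inp_kv_unified drops the .get() call, which suggests the new helper hands back a non-owning pointer. The sketch below illustrates that wiring with placeholder *_stub types; none of it is the real llama.cpp API, and the ownership detail is an assumption drawn only from the call sites in the diff.

// Self-contained sketch with invented stub types, not llama.cpp declarations.
#include <memory>

struct ggml_cgraph_stub {};      // stand-in for ggml_cgraph
struct llm_graph_params_stub {}; // stand-in for llm_graph_params
struct llama_model_stub {};      // stand-in for llama_model
struct attn_input_stub {};       // stand-in for the attention input object

struct graph_context_stub {
    explicit graph_context_stub(const llm_graph_params_stub &) {}

    // assumed: the context owns the attention input and returns a raw pointer to it
    attn_input_stub * build_attn_inp_kv_unified(bool /*causal*/, bool /*swa*/) {
        inp_attn_ = std::make_unique<attn_input_stub>();
        return inp_attn_.get();
    }

private:
    std::unique_ptr<attn_input_stub> inp_attn_;
};

struct gemma3_builder_stub : graph_context_stub {
    gemma3_builder_stub(const llama_model_stub & model,
                        const llm_graph_params_stub & params,
                        ggml_cgraph_stub * /*gf*/)
        : graph_context_stub(params), model(model) {
        auto * inp_attn = build_attn_inp_kv_unified(true, true); // raw pointer, no .get()
        (void)inp_attn;
    }
    const llama_model_stub & model; // weights would be looked up through this reference
};

int main() {
    llama_model_stub   model;
    llm_graph_params_stub params;
    // mirrors std::make_unique<llm_build_gemma3>(*this, params, gf) at the call site
    auto llm = std::make_unique<gemma3_builder_stub>(model, params, nullptr);
    (void)llm;
    return 0;
}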
