
Commit a366a3d

saood06 and sszymczy authored
Load all MoE experts during warmup and make warmup 1 token (#198)
* Load all MoE experts during warmup

Co-authored-by: Stanisław Szymczyk <[email protected]>

* Unify warmup to one token

---------

Co-authored-by: Stanisław Szymczyk <[email protected]>
1 parent c12f73b commit a366a3d

File tree: 3 files changed (+17, -10)

common/common.cpp
examples/llama-bench/llama-bench.cpp
src/llama.cpp
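Taken together, the diffs below make two changes: warmup now decodes a single token (common/common.cpp for the common init path, llama-bench for the benchmark), and the graph builder in src/llama.cpp recognizes that single-token warmup batch and routes it through every MoE expert instead of the usual top-k, so all expert tensors are touched (and, with mmap, paged into memory) before any timed or user-visible work starts.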

common/common.cpp

Lines changed: 4 additions & 2 deletions
@@ -2169,8 +2169,10 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
         if (bos != -1) {
             tmp.push_back(bos);
         }
-        tmp.push_back(eos);
-
+        else
+        {
+            tmp.push_back(eos);
+        }
         if (llama_model_has_encoder(model)) {
             llama_encode(lctx, llama_batch_get_one(tmp.data(), tmp.size(), 0, 0));
             llama_token decoder_start_token_id = llama_model_decoder_start_token(model);
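This hunk makes the warmup prompt exactly one token: BOS when the model defines one, EOS otherwise (previously BOS and EOS were pushed together, giving a two-token warmup). For context, a minimal sketch of the warmup path this feeds, assuming the pre-refactor llama.h API used in this tree (illustrative, not the literal code):

    std::vector<llama_token> tmp;
    const llama_token bos = llama_token_bos(model);
    const llama_token eos = llama_token_eos(model);
    if (bos != -1) {
        tmp.push_back(bos);   // models with a BOS token warm up on BOS
    } else {
        tmp.push_back(eos);   // otherwise fall back to EOS
    }
    // one decode builds and runs the graph; the src/llama.cpp hunks below
    // make this single pass touch every MoE expert
    llama_decode(lctx, llama_batch_get_one(tmp.data(), tmp.size(), 0, 0));
    llama_kv_cache_clear(lctx);   // drop the warmup token from the KV cache

The resulting single-BOS batch is also the signature the warmup detector in the last src/llama.cpp hunk keys on.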

examples/llama-bench/llama-bench.cpp

Lines changed: 1 addition & 1 deletion
@@ -1586,7 +1586,7 @@ int main(int argc, char ** argv) {
         if (params.warmup) {
             if (t.n_prompt > 0) {
                 //test_prompt(ctx, std::min(t.n_batch, std::min(t.n_prompt, 32)), 0, t.n_batch, t.n_threads);
-                test_prompt(ctx, t.n_prompt, 0, t.n_batch, t.n_threads);
+                test_prompt(ctx, 1, 0, t.n_batch, t.n_threads);
             }
             if (t.n_gen > 0) {
                 test_gen(ctx, 1, 0, t.n_threads);
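llama-bench's warmup previously processed the full prompt length (t.n_prompt tokens), so warmup cost scaled with -p; it now decodes one token. A simplified sketch of the shape of the test_prompt helper, hedged since the real helper also randomizes token ids and checks return codes:

    #include <algorithm>
    #include <vector>
    #include "llama.h"

    // decode n_prompt tokens in chunks of at most n_batch; with n_prompt == 1
    // the warmup collapses to a single one-token llama_decode call
    static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_batch, int n_threads) {
        llama_set_n_threads(ctx, n_threads, n_threads);
        std::vector<llama_token> tokens(std::min(n_prompt, n_batch),
                                        llama_token_bos(llama_get_model(ctx)));
        int n_processed = 0;
        while (n_processed < n_prompt) {
            const int n_tokens = std::min(n_prompt - n_processed, n_batch);
            llama_decode(ctx, llama_batch_get_one(tokens.data(), n_tokens, n_past + n_processed, 0));
            n_processed += n_tokens;
        }
    }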

src/llama.cpp

Lines changed: 12 additions & 7 deletions
@@ -3784,7 +3784,7 @@ static size_t llama_model_max_nodes(const llama_model & /*model*/) {
     //    return 32768;
     //}
 
-    return 8192;
+    return 65536;
 }
 
 struct llama_model_loader {
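The node-budget bump from 8192 to 65536 is presumably headroom for the larger graphs this patch can produce once every expert of a many-expert MoE model participates during warmup; the diff offers no derivation, so 65536 reads as an empirically safe ceiling rather than a computed bound.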
@@ -8879,7 +8879,8 @@
         llama_context & lctx,
         const llama_batch & batch,
         const llm_build_cb & cb,
-        bool worst_case) :
+        bool worst_case,
+        bool warmup) :
         model            (lctx.model),
         lctx             (lctx),
         hparams          (model.hparams),
@@ -8897,7 +8898,7 @@
         n_embd_head_v    (hparams.n_embd_head_v),
         n_embd_v_gqa     (hparams.n_embd_v_gqa()),
         n_expert         (hparams.n_expert),
-        n_expert_used    (hparams.n_expert_used),
+        n_expert_used    (warmup ? hparams.n_expert : hparams.n_expert_used),
         freq_base        (cparams.rope_freq_base),
         freq_scale       (cparams.rope_freq_scale),
         ext_factor       (cparams.yarn_ext_factor),
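This line is the heart of the change: n_expert_used is the k of the MoE gate's top-k selection, so a graph built with warmup == true selects all n_expert experts. A hedged sketch of where the value lands, loosely following llama.cpp's llm_build_moe_ffn conventions (names like gate_inp and ffn_up_exps follow that convention; this is not the literal code):

    // the router scores every expert per token, then keeps the top n_expert_used
    struct ggml_tensor * logits   = ggml_mul_mat(ctx0, gate_inp, cur);       // [n_expert, n_tokens]
    struct ggml_tensor * probs    = ggml_soft_max(ctx0, logits);
    struct ggml_tensor * selected = ggml_top_k(ctx0, probs, n_expert_used);  // [n_expert_used, n_tokens]

    // the selected ids drive indirect matmuls over the stacked expert weights;
    // with n_expert_used == n_expert every expert id is selected, so every
    // expert tensor participates and mmap'd weights are faulted in during warmup
    struct ggml_tensor * up = ggml_mul_mat_id(ctx0, ffn_up_exps, cur, selected);

Because the override happens once in the constructor, every architecture's MoE build path picks it up without per-model edits.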
@@ -14433,7 +14434,7 @@ static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const
 
     llm_build_cb cb = [&](struct ggml_tensor * , const char * , int ) { };
 
-    struct llm_build_context llm(lctx, dummy, cb, false);
+    struct llm_build_context llm(lctx, dummy, cb, false, false);
 
     llm.init();
 
@@ -14450,7 +14451,7 @@ static struct ggml_cgraph * llama_build_graph_k_shift(llama_context & lctx) {
 
     llm_build_cb cb = [&](struct ggml_tensor * , const char * , int ) { };
 
-    struct llm_build_context llm(lctx, dummy, cb, false);
+    struct llm_build_context llm(lctx, dummy, cb, false, false);
 
     llm.init();
 
@@ -14467,7 +14468,7 @@ static struct ggml_cgraph * llama_build_graph_s_copy(llama_context & lctx) {
 
     llm_build_cb cb = [&](struct ggml_tensor * , const char * , int ) { };
 
-    struct llm_build_context llm(lctx, dummy, cb, false);
+    struct llm_build_context llm(lctx, dummy, cb, false, false);
 
     llm.init();
 
@@ -14517,7 +14518,11 @@ static struct ggml_cgraph * llama_build_graph(
 
     struct ggml_cgraph * result = NULL;
 
-    struct llm_build_context llm(lctx, batch, cb, worst_case);
+    const llama_vocab * vocab = llama_get_vocab(&lctx);
+    llama_token bos = llama_token_bos_impl(*vocab);
+    llama_token eos = llama_token_eos_impl(*vocab);
+    bool is_warming_up = (batch.n_tokens == 1 && batch.token[0] == bos);
+    struct llm_build_context llm(lctx, batch, cb, worst_case, is_warming_up);
 
     llm.init();
 
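No explicit warmup flag is threaded through llama_decode into the graph builder, so the final hunk infers warmup from the batch shape: exactly one token, equal to BOS, which is precisely the batch the common/common.cpp hunk constructs. Two caveats, hedged since only the diff is visible here: eos is fetched but unused in the check as rendered, so a model without a BOS token (whose one-token warmup batch holds EOS) would apparently not be detected; and a genuine user decode consisting of a single BOS token would be treated as warmup, with its logits computed from a mixture of all experts rather than the usual top-k.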
