llama: Attempt to add ModernBert #14014

Open · huydt84 wants to merge 20 commits into master

Conversation

huydt84 (Contributor) commented Jun 4, 2025

I don't know whether my implementation is correct or not

github-actions bot added the python (python script changes) label Jun 4, 2025
huydt84 marked this pull request as draft Jun 4, 2025 15:27
huydt84 marked this pull request as ready for review Jun 4, 2025 15:36
huydt84 (Contributor, Author) commented Jun 4, 2025

hparams.set_swa_pattern doesn't work properly with ModernBert.

huydt84 marked this pull request as draft Jun 4, 2025 15:40
huydt84 (Contributor, Author) commented Jun 4, 2025

The embedding results seem random and very low. Something is wrong here.

huydt84 marked this pull request as ready for review Jun 4, 2025 16:21
CISC (Collaborator) left a comment

Delete the files you added in models, we don't need them; just make sure test-tokenizer-0 succeeds with the GGUF.

huydt84 requested a review from CISC Jun 4, 2025 22:55
inpL = build_norm(inpL, model.tok_norm, nullptr, LLM_NORM, -1);
cb(inpL, "inp_norm", -1);

auto * inp_attn = build_attn_inp_kv_unified_iswa();
ggerganov (Member) commented Jun 5, 2025

This should probably become:

Suggested change:
-    auto * inp_attn = build_attn_inp_kv_unified_iswa();
+    auto * inp_attn = build_attn_inp_no_cache_iswa();

And add the corresponding mask logic in llama-graph. Special attention should be taken about how the SWA works for this model - i.e. is it symmetric or not:

# non-symmetric
token i attends to [i - n_swa, i]

# symmetric:
token i attends to [i - n_swa/2, i + n_swa/2]
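A minimal sketch of the mask predicate this distinction implies (the helper name, the symmetric flag, and reading the window from a plain n_swa value are illustrative assumptions, not llama.cpp API):

    #include <cstdint>
    #include <cstdlib>

    // Returns true if the query token at position p1 may attend to the key/value token
    // at position p0, given a sliding window of size n_swa (0 = window disabled).
    static bool swa_allowed(int32_t p0, int32_t p1, int32_t n_swa, bool symmetric) {
        if (n_swa <= 0) {
            return true; // no window
        }
        if (symmetric) {
            // token p1 attends to [p1 - n_swa/2, p1 + n_swa/2]
            return std::abs(p1 - p0) <= n_swa/2;
        }
        // non-symmetric: token p1 attends to [p1 - n_swa, p1]
        return p0 <= p1 && p1 - p0 <= n_swa;
    }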

ggerganov (Member) left a comment

You have to add the new arch here:

llama.cpp/src/llama-model.cpp, lines 13195-13203 (5a8ae30):

    switch (arch) {
        case LLM_ARCH_BERT:
        case LLM_ARCH_JINA_BERT_V2:
        case LLM_ARCH_NOMIC_BERT:
        case LLM_ARCH_NOMIC_BERT_MOE:
        case LLM_ARCH_WAVTOKENIZER_DEC:
            {
                res = nullptr;
            } break;

To avoid creating a memory module (a.k.a. KV cache) for these models.
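A sketch of the requested change, assuming the new architecture enum introduced by this PR is named LLM_ARCH_MODERN_BERT (the actual name may differ):

    switch (arch) {
        case LLM_ARCH_BERT:
        case LLM_ARCH_JINA_BERT_V2:
        case LLM_ARCH_NOMIC_BERT:
        case LLM_ARCH_NOMIC_BERT_MOE:
        case LLM_ARCH_MODERN_BERT:      // assumed enum name for ModernBert
        case LLM_ARCH_WAVTOKENIZER_DEC:
            {
                // embedding-only models do not get a memory module (KV cache)
                res = nullptr;
            } break;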

huydt84 requested a review from ggerganov Jun 5, 2025 13:55
CISC (Collaborator) commented Jun 5, 2025

So, since the vocab is BPE you need to add modern-bert vocab handling in a few places:

tokenizer_pre == "roberta-bpe") {

Set the correct attribute on the [MASK] token, similarly to this:

llama.cpp/src/llama-vocab.cpp, lines 2097-2105 (9f47fa5):

    if (false
        || _contains_any(tokenizer_pre, {"jina-v2-de", "jina-v2-es", "jina-v2-code"})
        || _contains_any(general_arch, {"nomic-bert-moe"})
    ) {
        if (token_to_id.count("<mask>") == 0) {
            LLAMA_LOG_WARN("%s: Mask token is missing in vocab, please reconvert model!\n", __func__);
        } else {
            _set_token_attr("<mask>", LLAMA_TOKEN_ATTR_LSTRIP, true);
        }
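A sketch of what the two additions could look like (the pre-tokenizer id "modern-bert" and the "[MASK]" literal are assumptions about how this PR names things, not confirmed values):

    // 1) recognize the new pre-tokenizer alongside the existing roberta-bpe check:
    //         tokenizer_pre == "roberta-bpe" ||
    //         tokenizer_pre == "modern-bert") {

    // 2) extend the mask-token attribute handling:
    if (false
        || _contains_any(tokenizer_pre, {"jina-v2-de", "jina-v2-es", "jina-v2-code", "modern-bert"})
        || _contains_any(general_arch, {"nomic-bert-moe"})
    ) {
        if (token_to_id.count("[MASK]") == 0) {
            LLAMA_LOG_WARN("%s: Mask token is missing in vocab, please reconvert model!\n", __func__);
        } else {
            _set_token_attr("[MASK]", LLAMA_TOKEN_ATTR_LSTRIP, true);
        }
    }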

CISC (Collaborator) commented Jun 5, 2025

> The embedding results seem random and very low. Something is wrong here.

Yep, I also noticed the same with jina-reranker-v2, most likely the same issue, will investigate.

CISC (Collaborator) commented Jun 6, 2025

> So, since the vocab is BPE you need to add modern-bert vocab handling in a few places:

@huydt84 Don't forget this-^ it's important.

CISC (Collaborator) commented Jun 6, 2025

> The embedding results seem random and very low. Something is wrong here.
>
> Yep, I also noticed the same with jina-reranker-v2, most likely the same issue, will investigate.

Will dig into this tonight/this weekend...

huydt84 (Contributor, Author) commented Jun 6, 2025

> So, since the vocab is BPE you need to add modern-bert vocab handling in a few places:
>
> @huydt84 Don't forget this-^ it's important.

Thank you! I have just added it

CISC (Collaborator) commented Jun 6, 2025

> @huydt84 Don't forget this-^ it's important.
>
> Thank you! I have just added it

The tokenizer_pre check is the most important one, please add that too. :)

ggerganov (Member) left a comment

Need to add a new llama_swa_type enum value:

LLAMA_SWA_TYPE_SYMMETRIC  = 3,
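For context, a sketch of where this would land, assuming the existing values in llama-hparams.h on master are NONE/STANDARD/CHUNKED (worth double-checking):

    enum llama_swa_type {
        LLAMA_SWA_TYPE_NONE      = 0,
        LLAMA_SWA_TYPE_STANDARD  = 1,
        LLAMA_SWA_TYPE_CHUNKED   = 2,
        LLAMA_SWA_TYPE_SYMMETRIC = 3, // proposed: window extends on both sides of the query token
    };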

inpL = build_norm(inpL, model.tok_norm, nullptr, LLM_NORM, -1);
cb(inpL, "inp_norm", -1);

auto * inp_attn = build_attn_inp_no_cache_iswa();
Review comment (Member):

Since this is not an actual iSWA (interleaved SWA) model, we should simply use build_attn_inp_no_cache().
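That is, the call site quoted above would presumably become:

    auto * inp_attn = build_attn_inp_no_cache();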

@@ -241,6 +249,7 @@ class llm_graph_input_attn_no_cache : public llm_graph_input_i {

     const llama_hparams & hparams;
     const llama_cparams & cparams;
+    const int n_swa; // Sliding window attention size (0 = disabled)
Review comment (Member):

This is already available from the hparams - no need to duplicate it here.

Comment on lines 281 to 284:

    // Check if we're using sliding window attention
    if (n_swa > 0) {
        const int64_t n_tokens     = ubatch->n_tokens;
        const int64_t n_seq_tokens = ubatch->n_seq_tokens;
Review comment (Member):

This branch is actually non-causal attention + sliding window. So merge it with the existing implementation below.
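A rough sketch of what the merged non-causal mask construction could look like (variable names follow the quoted snippet; reading the window from hparams.n_swa and treating it as symmetric are assumptions about how this PR ends up wiring it):

    for (int64_t i1 = 0; i1 < n_tokens; ++i1) {        // query token
        for (int64_t i0 = 0; i0 < n_tokens; ++i0) {    // key/value token
            float f = -INFINITY;

            const bool same_seq  = ubatch->seq_id[i0][0] == ubatch->seq_id[i1][0];
            const bool in_window = hparams.n_swa == 0 ||
                std::abs(ubatch->pos[i1] - ubatch->pos[i0]) <= (int32_t) hparams.n_swa/2;

            if (same_seq && in_window) {
                f = 0.0f; // allow attention
            }

            data[i1*n_tokens + i0] = f; // data points at the kq_mask buffer
        }
    }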

CISC (Collaborator) commented Jun 7, 2025

> The embedding results seem random and very low. Something is wrong here.
>
> Yep, I also noticed the same with jina-reranker-v2, most likely the same issue, will investigate.
>
> Will dig into this tonight/this weekend...

OK, the issue with jina-reranker-v2 was just that you have to apply a sigmoid and normalize; guess that sigmoid option could be useful, @ggerganov?

That doesn't explain the issue with modernbert, unfortunately (though I did try it for fun with Alibaba-NLP/gte-reranker-modernbert-base... it seems to give reversed scores).
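For reference, one plausible reading of that post-processing, as a standalone sketch (not code from this PR; "normalize" is interpreted here as scaling the sigmoided scores to sum to 1):

    #include <cmath>
    #include <vector>

    // Apply a sigmoid to raw reranker logits, then normalize the scores so they sum to 1.
    static std::vector<float> sigmoid_normalize(const std::vector<float> & logits) {
        std::vector<float> scores(logits.size());
        float sum = 0.0f;
        for (size_t i = 0; i < logits.size(); ++i) {
            scores[i] = 1.0f / (1.0f + std::exp(-logits[i]));
            sum += scores[i];
        }
        for (auto & s : scores) {
            s /= sum > 0.0f ? sum : 1.0f;
        }
        return scores;
    }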

huydt84 (Contributor, Author) commented Jun 8, 2025

@CISC cc: @ggerganov

I tried the embedding with various models, but the outputs barely change across those attempts. Maybe the parameter loading or the inference graph has a problem somewhere. Can you check that part?
This is the model implementation in Hugging Face: https://github.com/huggingface/transformers/blob/v4.52.3/src/transformers/models/modernbert/modeling_modernbert.py

CISC (Collaborator) commented Jun 8, 2025

So, I just noticed at least part of the problem:

llama.cpp/src/llama-graph.cpp, lines 1567-1571 (3ac6753):

    if (cls != nullptr && cls_b != nullptr) {
        // classification head
        // https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
        cur = ggml_add(ctx0, ggml_mul_mat(ctx0, cls, inp), cls_b);
        cur = ggml_tanh(ctx0, cur);

We have cls, but not cls_b, so this has to be modified to handle that...
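A sketch of one way to handle the missing bias (not necessarily the fix this PR lands on): require only cls and apply cls_b when it exists.

    if (cls != nullptr) {
        // classification head; bias is optional (ModernBert provides cls without cls_b)
        cur = ggml_mul_mat(ctx0, cls, inp);
        if (cls_b != nullptr) {
            cur = ggml_add(ctx0, cur, cls_b);
        }
        cur = ggml_tanh(ctx0, cur);
    }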

Comment on lines +6217 to +6240:

    // feed-forward network
    ggml_tensor * ffn_up = build_lora_mm(model.layers[il].ffn_up, cur);
    cb(ffn_up, "ffn_up", il);

    int64_t split_point = ffn_up->ne[0] / 2;
    ggml_tensor * output_ffn_up = ggml_cont(ctx0, ggml_view_2d(
        ctx0, ffn_up, split_point,
        ffn_up->ne[1], ffn_up->nb[1], 0
    ));
    ggml_tensor * output_ffn_gate = ggml_cont(ctx0, ggml_view_2d(
        ctx0, ffn_up, split_point,
        ffn_up->ne[1], ffn_up->nb[1],
        split_point * ggml_element_size(ffn_up)
    ));

    // Apply activation function
    output_ffn_up = ggml_gelu(ctx0, output_ffn_up);

    // Element-wise multiplication
    ggml_tensor * gated = ggml_mul(ctx0, output_ffn_up, output_ffn_gate);
    cb(gated, "ffn_gated", il);

    // Final projection
    cur = build_lora_mm(model.layers[il].ffn_down, gated);
Review comment (Collaborator):

This should be merged into build_ffn as LLM_FFN_GEGLU.
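For reference, the gated GELU above boils down to something like the following case inside build_ffn once an LLM_FFN_GEGLU op type exists (a sketch; the case name comes from the comment, the surrounding build_ffn plumbing is assumed):

    case LLM_FFN_GEGLU:
        {
            // cur holds the fused up+gate projection: split it in half along dim 0,
            // apply GELU to the first half and gate it with the second half
            const int64_t split_point = cur->ne[0] / 2;

            ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0));
            ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1],
                                                            split_point * ggml_element_size(cur)));

            x0  = ggml_gelu(ctx0, x0);
            cur = ggml_mul(ctx0, x0, x1);
            cb(cur, "ffn_geglu", il);
        } break;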

Review comment (Collaborator):

Probably worth making a separate PR for visibility.
