llama: Attempt to add ModernBert #14014

Status: Open. Wants to merge 20 commits into master.

Changes from 16 commits
31 changes: 31 additions & 0 deletions convert_hf_to_gguf.py
@@ -809,6 +809,9 @@ def get_vocab_base_pre(self, tokenizer) -> str:
if chkhsh == "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35":
# ref: https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0
res = "minerva-7b"
if chkhsh == "a0b64b4385f123663873756336c085744376d015ff328bb1d901598f63c44152":
# ref: https://huggingface.co/answerdotai/ModernBERT-base
res = "modern-bert"

if res is None:
logger.warning("\n")
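
(For context, a minimal sketch of how the chkhsh values above are derived; the real logic and the exact probe text live in convert_hf_to_gguf_update.py, and transformers availability is an assumption here:)

from hashlib import sha256
from transformers import AutoTokenizer  # assumption: transformers is installed

def pretokenizer_hash(repo: str, probe_text: str) -> str:
    # hash the token-ID sequence produced for a fixed probe string;
    # this fingerprints the pre-tokenizer behaviour, not the model weights
    tokenizer = AutoTokenizer.from_pretrained(repo)
    return sha256(str(tokenizer.encode(probe_text)).encode()).hexdigest()
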
@@ -3932,6 +3935,34 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
return super().modify_tensors(data_torch, name, bid)


@ModelBase.register("ModernBert", "ModernBertForMaskedLM", "ModernBertForSequenceClassification")
class ModernBertModel(BertModel):
model_arch = gguf.MODEL_ARCH.MODERN_BERT

def set_vocab(self):
self._set_vocab_gpt2()
self.gguf_writer.add_add_bos_token(True)
self.gguf_writer.add_add_eos_token(True)

def set_gguf_parameters(self):
super().set_gguf_parameters()
self.gguf_writer.add_sliding_window(self.hparams["local_attention"])
self.gguf_writer.add_rope_freq_base(self.hparams["global_rope_theta"])
self.gguf_writer.add_rope_freq_base_swa(self.hparams["local_rope_theta"])
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
self.gguf_writer.add_vocab_size(self.hparams["vocab_size"])

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# These layers act as the MLM head, so we don't need them
if name.startswith("decoder."):
return []

if name.startswith("model."):
name = name[6:]

return super().modify_tensors(data_torch, name, bid)
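
For reference, a hedged sketch of the config.json fields set_gguf_parameters() reads, plus how a test conversion might be invoked. The variable name is illustrative and the values are recalled for answerdotai/ModernBERT-base, so treat them as illustrative rather than authoritative:

# Hedged excerpt of the HF config.json fields read by set_gguf_parameters() above.
modernbert_base_hparams = {
    "local_attention": 128,         # sliding window size -> add_sliding_window()
    "global_rope_theta": 160000.0,  # -> add_rope_freq_base()
    "local_rope_theta": 10000.0,    # -> add_rope_freq_base_swa()
    "vocab_size": 50368,            # -> add_vocab_size()
}

# A test conversion would then look something like (paths illustrative):
#   python convert_hf_to_gguf.py /path/to/ModernBERT-base --outfile modernbert-base.gguf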


@ModelBase.register("RobertaModel", "RobertaForSequenceClassification")
class RobertaModel(BertModel):
model_arch = gguf.MODEL_ARCH.BERT
1 change: 1 addition & 0 deletions convert_hf_to_gguf_update.py
@@ -128,6 +128,7 @@ class TOKENIZER_TYPE(IntEnum):
{"name": "llama4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct", },
{"name": "pixtral", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/mistral-community/pixtral-12b", },
{"name": "seed-coder", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ByteDance-Seed/Seed-Coder-8B-Base", },
{"name": "modern-bert", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/answerdotai/ModernBERT-base", },
]

# some models are known to be broken upstream, so we will skip them as exceptions
19 changes: 19 additions & 0 deletions gguf-py/gguf/constants.py
@@ -147,6 +147,7 @@ class Rope:
DIMENSION_COUNT = "{arch}.rope.dimension_count"
DIMENSION_SECTIONS = "{arch}.rope.dimension_sections"
FREQ_BASE = "{arch}.rope.freq_base"
FREQ_BASE_SWA = "{arch}.rope.freq_base_swa"
SCALING_TYPE = "{arch}.rope.scaling.type"
SCALING_FACTOR = "{arch}.rope.scaling.factor"
SCALING_ATTN_FACTOR = "{arch}.rope.scaling.attn_factor"
@@ -289,6 +290,7 @@ class MODEL_ARCH(IntEnum):
STARCODER = auto()
REFACT = auto()
BERT = auto()
MODERN_BERT = auto()
NOMIC_BERT = auto()
NOMIC_BERT_MOE = auto()
JINA_BERT_V2 = auto()
@@ -477,6 +479,7 @@ class MODEL_TENSOR(IntEnum):
ENC_FFN_UP = auto()
ENC_OUTPUT_NORM = auto()
CLS = auto() # classifier
CLS_NORM = auto() # classifier normalization
CLS_OUT = auto() # classifier output projection
CONV1D = auto()
CONVNEXT_DW = auto()
@@ -569,6 +572,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_ARCH.STARCODER: "starcoder",
MODEL_ARCH.REFACT: "refact",
MODEL_ARCH.BERT: "bert",
MODEL_ARCH.MODERN_BERT: "modern-bert",
MODEL_ARCH.NOMIC_BERT: "nomic-bert",
MODEL_ARCH.NOMIC_BERT_MOE: "nomic-bert-moe",
MODEL_ARCH.JINA_BERT_V2: "jina-bert-v2",
@@ -757,6 +761,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.ENC_FFN_UP: "enc.blk.{bid}.ffn_up",
MODEL_TENSOR.ENC_OUTPUT_NORM: "enc.output_norm",
MODEL_TENSOR.CLS: "cls",
MODEL_TENSOR.CLS_NORM: "cls.norm",
MODEL_TENSOR.CLS_OUT: "cls.output",
MODEL_TENSOR.CONV1D: "conv1d",
MODEL_TENSOR.CONVNEXT_DW: "convnext.{bid}.dw",
@@ -1047,6 +1052,20 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.CLS,
MODEL_TENSOR.CLS_OUT,
],
MODEL_ARCH.MODERN_BERT: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.TOKEN_EMBD_NORM,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_QKV,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.ATTN_OUT_NORM,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.ENC_OUTPUT_NORM,
MODEL_TENSOR.CLS,
MODEL_TENSOR.CLS_NORM,
MODEL_TENSOR.CLS_OUT,
],
MODEL_ARCH.NOMIC_BERT: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.TOKEN_EMBD_NORM,
3 changes: 3 additions & 0 deletions gguf-py/gguf/gguf_writer.py
@@ -813,6 +813,9 @@ def add_rope_dimension_sections(self, dims: Sequence[int]) -> None:
def add_rope_freq_base(self, value: float) -> None:
self.add_float32(Keys.Rope.FREQ_BASE.format(arch=self.arch), value)

def add_rope_freq_base_swa(self, value: float) -> None:
self.add_float32(Keys.Rope.FREQ_BASE_SWA.format(arch=self.arch), value)

def add_rope_scaling_type(self, value: RopeScalingType) -> None:
self.add_string(Keys.Rope.SCALING_TYPE.format(arch=self.arch), value.value)

14 changes: 14 additions & 0 deletions gguf-py/gguf/tensor_mapping.py
@@ -16,6 +16,7 @@ class TensorNameMap:
"model.embed_tokens", # llama-hf nemotron olmoe olmo2 rwkv6qwen2 glm4-0414
"tok_embeddings", # llama-pth
"embeddings.word_embeddings", # bert nomic-bert
"embeddings.tok_embeddings", # modern-bert
"language_model.embedding.word_embeddings", # persimmon
"wte", # gpt2
"transformer.embd.wte", # phi2
@@ -42,6 +43,7 @@
MODEL_TENSOR.TOKEN_EMBD_NORM: (
"word_embeddings_layernorm", # bloom
"embeddings.LayerNorm", # bert
"embeddings.norm", # modern-bert
"emb_ln", # nomic-bert
"transformer.norm", # openelm
"rwkv.blocks.0.pre_ln", # rwkv
@@ -134,6 +136,7 @@
"rwkv.blocks.{bid}.ln1", # rwkv6
"model.layers.{bid}.ln1", # rwkv7
"model.layers.{bid}.input_layernorm", # llama4
"layers.{bid}.attn_norm", # modern-bert
),

# Attention norm 2
@@ -161,6 +164,7 @@
"model.layers.{bid}.self_attn.qkv_proj", # phi3
"encoder.layers.{bid}.self_attention.query_key_value", # chatglm
"transformer.layers.{bid}.attn.qkv_proj", # openelm
"layers.{bid}.attn.Wqkv", # modern-bert
),

# Attention query
@@ -236,6 +240,7 @@
"transformer.layers.{bid}.attn.out_proj", # openelm
"transformer.h.{bid}.attn.attention.out_proj", # exaone
"model.layers.{bid}.self_attn.o_proj", # llama4
"layers.{bid}.attn.Wo", # modern-bert
),

# Attention output norm
@@ -245,6 +250,7 @@
"encoder.layers.{bid}.norm1", # nomic-bert
"transformer.decoder_layer.{bid}.rms_norm_1", # Grok
"transformer.blocks.{bid}.norm_attn_norm.norm_2", # dbrx
"layers.{bid}.mlp_norm" # modern-bert
),

MODEL_TENSOR.ATTN_POST_NORM: (
@@ -338,6 +344,7 @@
"encoder.layers.{bid}.mlp.dense_h_to_4h", # chatglm
"transformer.h.{bid}.mlp.c_fc_1", # exaone
"model.layers.{bid}.feed_forward.up_proj", # llama4
"layers.{bid}.mlp.Wi" # modern-bert
),

MODEL_TENSOR.FFN_UP_EXP: (
@@ -420,6 +427,7 @@
"encoder.layers.{bid}.mlp.dense_4h_to_h", # chatglm
"model.layers.h.{bid}.mlp.c_proj", # exaone
"model.layers.{bid}.feed_forward.down_proj", # llama4
"layers.{bid}.mlp.Wo" # modern-bert
),

MODEL_TENSOR.FFN_DOWN_EXP: (
@@ -830,12 +838,18 @@ class TensorNameMap:
# TODO: these do not belong to block_mappings_cfg - move them to mappings_cfg
MODEL_TENSOR.ENC_OUTPUT_NORM: (
"encoder.final_layer_norm", # t5
"final_norm", # modern-bert
),

MODEL_TENSOR.CLS: (
"classifier", # jina
"classifier.dense", # roberta
"pre_classifier", # distillbert
"head.dense", # modern-bert
),

MODEL_TENSOR.CLS_NORM: (
"head.norm", # modern-bert
),

MODEL_TENSOR.CLS_OUT: (
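Taken together, these mappings route ModernBERT checkpoint tensor names to GGUF names. A hedged illustration for block 3 (left: HF name after the converter strips the "model." prefix; right: resulting GGUF name; the block index is illustrative):

name_mapping_examples = {
    "embeddings.tok_embeddings.weight": "token_embd.weight",
    "layers.3.attn.Wqkv.weight":        "blk.3.attn_qkv.weight",
    "layers.3.attn.Wo.weight":          "blk.3.attn_output.weight",
    "layers.3.mlp.Wi.weight":           "blk.3.ffn_up.weight",
    "layers.3.mlp.Wo.weight":           "blk.3.ffn_down.weight",
    "head.norm.weight":                 "cls.norm.weight",
}
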
20 changes: 20 additions & 0 deletions src/llama-arch.cpp
@@ -18,6 +18,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_STARCODER, "starcoder" },
{ LLM_ARCH_REFACT, "refact" },
{ LLM_ARCH_BERT, "bert" },
{ LLM_ARCH_MODERN_BERT, "modern-bert" },
{ LLM_ARCH_NOMIC_BERT, "nomic-bert" },
{ LLM_ARCH_NOMIC_BERT_MOE, "nomic-bert-moe" },
{ LLM_ARCH_JINA_BERT_V2, "jina-bert-v2" },
@@ -148,6 +149,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" },
{ LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" },
{ LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" },
{ LLM_KV_ROPE_FREQ_BASE_SWA, "%s.rope.freq_base_swa" },
{ LLM_KV_ROPE_SCALE_LINEAR, "%s.rope.scale_linear" },
{ LLM_KV_ROPE_SCALING_TYPE, "%s.rope.scaling.type" },
{ LLM_KV_ROPE_SCALING_FACTOR, "%s.rope.scaling.factor" },
@@ -462,6 +464,23 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_CLS_OUT, "cls.output" },
},
},
{
LLM_ARCH_MODERN_BERT,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_ATTN_OUT_NORM, "blk.%d.attn_output_norm" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_ENC_OUTPUT_NORM, "enc.output_norm" },
{ LLM_TENSOR_CLS, "cls" },
{ LLM_TENSOR_CLS_NORM, "cls.norm" },
{ LLM_TENSOR_CLS_OUT, "cls.output" },
},
},
{
LLM_ARCH_NOMIC_BERT,
{
@@ -1572,6 +1591,7 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
{LLM_TENSOR_OUTPUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
{LLM_TENSOR_CLS, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
{LLM_TENSOR_CLS_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
{LLM_TENSOR_CLS_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
{LLM_TENSOR_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
{LLM_TENSOR_DEC_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
{LLM_TENSOR_ENC_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
3 changes: 3 additions & 0 deletions src/llama-arch.h
@@ -22,6 +22,7 @@ enum llm_arch {
LLM_ARCH_STARCODER,
LLM_ARCH_REFACT,
LLM_ARCH_BERT,
LLM_ARCH_MODERN_BERT,
LLM_ARCH_NOMIC_BERT,
LLM_ARCH_NOMIC_BERT_MOE,
LLM_ARCH_JINA_BERT_V2,
@@ -152,6 +153,7 @@ enum llm_kv {
LLM_KV_ROPE_DIMENSION_COUNT,
LLM_KV_ROPE_DIMENSION_SECTIONS,
LLM_KV_ROPE_FREQ_BASE,
LLM_KV_ROPE_FREQ_BASE_SWA,
LLM_KV_ROPE_SCALE_LINEAR,
LLM_KV_ROPE_SCALING_TYPE,
LLM_KV_ROPE_SCALING_FACTOR,
@@ -347,6 +349,7 @@ enum llm_tensor {
LLM_TENSOR_ENC_FFN_UP,
LLM_TENSOR_ENC_OUTPUT_NORM,
LLM_TENSOR_CLS,
LLM_TENSOR_CLS_NORM,
LLM_TENSOR_CLS_OUT,
LLM_TENSOR_CONV1D,
LLM_TENSOR_CONVNEXT_DW,
1 change: 1 addition & 0 deletions src/llama-context.cpp
@@ -2027,6 +2027,7 @@ void llama_context::opt_init(struct llama_model * model, struct llama_opt_params
llama_set_param(model->output_norm_enc, param_filter, param_filter_ud);
llama_set_param(model->cls, param_filter, param_filter_ud);
llama_set_param(model->cls_b, param_filter, param_filter_ud);
llama_set_param(model->cls_norm, param_filter, param_filter_ud);
llama_set_param(model->cls_out, param_filter, param_filter_ud);
llama_set_param(model->cls_out_b, param_filter, param_filter_ud);

80 changes: 78 additions & 2 deletions src/llama-graph.cpp
@@ -278,7 +278,61 @@ void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {

void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
if (kq_mask) {
// Check if we're using sliding window attention
if (n_swa > 0) {
const int64_t n_tokens = ubatch->n_tokens;
const int64_t n_seq_tokens = ubatch->n_seq_tokens;
[Review comment from a Member] This branch is actually non-causal attention + sliding window. So merge it with the existing implementation below.

const int64_t n_seqs = ubatch->n_seqs;
const int64_t n_stride = ubatch->n_tokens;
const int64_t half_n_swa = n_swa / 2;

GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
float * data = (float *) kq_mask->data;

// Implement symmetric sliding window attention
// token i attends to tokens [i - n_swa/2, i + n_swa/2]
for (int h = 0; h < 1; ++h) {
for (int s1 = 0; s1 < n_seqs; ++s1) {
const llama_seq_id seq_id = ubatch->seq_id[s1][0];

for (int j = 0; j < n_seq_tokens; ++j) {
const int32_t tj = s1*n_seq_tokens + j;
const int64_t pos_j = ubatch->pos[tj];

for (int s0 = 0; s0 < n_seqs; ++s0) {
for (int i = 0; i < n_seq_tokens; ++i) {
const int32_t ti = s0*n_seq_tokens + i;
float f = -INFINITY;

for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
if (ubatch->seq_id[s0][s] == seq_id) {
const int64_t pos_i = ubatch->pos[ti];
const int64_t pos_diff = pos_j - pos_i;

// Apply sliding window constraint
// [i - n_swa/2, i + n_swa/2]
if (pos_diff >= -half_n_swa && pos_diff <= half_n_swa) {
if (hparams.use_alibi) {
f = -std::abs(pos_diff);
} else {
f = 0.0f;
}
}
break;
}
}

data[h*(n_tokens*n_tokens) + tj*n_stride + ti] = f;
}
}

for (int i = n_tokens; i < n_stride; ++i) {
data[h*(n_tokens*n_tokens) + tj*n_stride + i] = -INFINITY;
}
}
}
}
} else if (cparams.causal_attn) {
const int64_t n_kv = ubatch->n_tokens;
const int64_t n_tokens = ubatch->n_tokens;
const int64_t n_seq_tokens = ubatch->n_seq_tokens;
@@ -1188,6 +1242,22 @@ llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() const {
return (llm_graph_input_attn_no_cache *) res->add_input(std::move(inp));
}

llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache_iswa() const {
// Use the sliding window size from hyperparameters
// If hparams.n_swa is 0, use a default value (128)
const int n_swa = hparams.n_swa > 0 ? hparams.n_swa : 128;

auto inp = std::make_unique<llm_graph_input_attn_no_cache>(hparams, cparams, n_swa);

// note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch
inp->kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
ggml_set_input(inp->kq_mask);

inp->kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->kq_mask, GGML_TYPE_F16) : inp->kq_mask;

return (llm_graph_input_attn_no_cache *) res->add_input(std::move(inp));
}

ggml_tensor * llm_graph_context::build_attn(
llm_graph_input_attn_no_cache * inp,
ggml_cgraph * gf,
@@ -1522,7 +1592,8 @@ void llm_graph_context::build_pooling(
ggml_tensor * cls,
ggml_tensor * cls_b,
ggml_tensor * cls_out,
ggml_tensor * cls_out_b,
ggml_tensor * cls_norm) const {
if (!cparams.embeddings) {
return;
}
@@ -1570,6 +1641,11 @@
cur = ggml_add(ctx0, ggml_mul_mat(ctx0, cls, inp), cls_b);
cur = ggml_tanh(ctx0, cur);

if (cls_norm) {
// normalization head
cur = build_norm(cur, cls_norm, nullptr, LLM_NORM, 0);
}

// some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
// https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896
if (cls_out) {
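With this change, the effective classifier-head order during pooling is: cls dense projection, tanh, the new classifier norm, then the optional output projection. A hedged PyTorch-style sketch of that order (torch availability and the tensor names are assumptions, not part of this PR):

import torch
import torch.nn.functional as F

def pooling_head(x, cls_w, cls_b, cls_norm_w=None, cls_out_w=None, cls_out_b=None):
    x = torch.tanh(x @ cls_w.T + cls_b)                       # cls / cls_b
    if cls_norm_w is not None:                                # cls_norm (new)
        x = F.layer_norm(x, x.shape[-1:], weight=cls_norm_w)
    if cls_out_w is not None:                                 # cls_out
        x = x @ cls_out_w.T
        if cls_out_b is not None:                             # cls_out_b
            x = x + cls_out_b
    return x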