-
Notifications
You must be signed in to change notification settings - Fork 12k
llama: Attempt to add ModernBert #14014
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
huydt84
wants to merge
20
commits into
ggml-org:master
Choose a base branch
from
huydt84:huydt/mb
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from 16 commits
Commits
Show all changes
20 commits
Select commit
Hold shift + click to select a range
045b1ac
llama: attempt to add modern-bert
huydt-bti 95f49d9
Merge branch 'master' into huydt/mb
huydt-bti eab776e
re-format and delete unused implementations
huydt-bti 7143840
overload set_swa_pattern for modern bert
huydt-bti 6aa1335
modern-bert doesn't have bias
huydt-bti 9e1179a
delete unnecessary files
huydt-bti fa23480
add build_attn_inp_no_cache_iswa with symmetric swa
huydt-bti a72cb3b
add modern-bert to llama_model::create_memory
huydt-bti adea1c9
fix lint
huydt-bti cfebb6e
access n_swa via hparams
huydt-bti 31e87e4
revert changes in convert script
huydt-bti 1004327
add set_vocab to modernbert convert class
huydt-bti 81f4797
Merge branch 'master' into huydt/mb
huydt-bti 03693fa
parmas-related fix
huydt-bti 2f5a72f
handle mask token in modern-bert bpe
huydt-bti 68f399e
add modern-bert to pre_type check
huydt-bti ad2a19a
change log warning when no mask token of modern-bert
huydt-bti c6b84e2
fix modern-bert swa logic
huydt-bti 6751e69
fix modern-bert class register
huydt-bti 8b794f9
Merge branch 'master' into huydt/mb
huydt-bti File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -278,7 +278,61 @@ void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) { | |
|
||
void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) { | ||
if (kq_mask) { | ||
if (cparams.causal_attn) { | ||
// Check if we're using sliding window attention | ||
if (n_swa > 0) { | ||
const int64_t n_tokens = ubatch->n_tokens; | ||
const int64_t n_seq_tokens = ubatch->n_seq_tokens; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This branch is actually |
||
const int64_t n_seqs = ubatch->n_seqs; | ||
const int64_t n_stride = ubatch->n_tokens; | ||
const int64_t half_n_swa = n_swa / 2; | ||
|
||
GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer)); | ||
float * data = (float *) kq_mask->data; | ||
|
||
// Implement symmetric sliding window attention | ||
// token i attends to tokens [i - n_swa/2, i + n_swa/2] | ||
for (int h = 0; h < 1; ++h) { | ||
for (int s1 = 0; s1 < n_seqs; ++s1) { | ||
const llama_seq_id seq_id = ubatch->seq_id[s1][0]; | ||
|
||
for (int j = 0; j < n_seq_tokens; ++j) { | ||
const int32_t tj = s1*n_seq_tokens + j; | ||
const int64_t pos_j = ubatch->pos[tj]; | ||
|
||
for (int s0 = 0; s0 < n_seqs; ++s0) { | ||
for (int i = 0; i < n_seq_tokens; ++i) { | ||
const int32_t ti = s0*n_seq_tokens + i; | ||
float f = -INFINITY; | ||
|
||
for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) { | ||
if (ubatch->seq_id[s0][s] == seq_id) { | ||
const int64_t pos_i = ubatch->pos[ti]; | ||
const int64_t pos_diff = pos_j - pos_i; | ||
|
||
// Apply sliding window constraint | ||
// [i - n_swa/2, i + n_swa/2] | ||
if (pos_diff >= -half_n_swa && pos_diff <= half_n_swa) { | ||
if (hparams.use_alibi) { | ||
f = -std::abs(pos_diff); | ||
} else { | ||
f = 0.0f; | ||
} | ||
} | ||
break; | ||
} | ||
} | ||
|
||
data[h*(n_tokens*n_tokens) + tj*n_stride + ti] = f; | ||
} | ||
} | ||
|
||
for (int i = n_tokens; i < n_stride; ++i) { | ||
data[h*(n_tokens*n_tokens) + tj*n_stride + i] = -INFINITY; | ||
} | ||
} | ||
} | ||
} | ||
} else if (cparams.causal_attn) { | ||
const int64_t n_kv = ubatch->n_tokens; | ||
const int64_t n_tokens = ubatch->n_tokens; | ||
const int64_t n_seq_tokens = ubatch->n_seq_tokens; | ||
|
@@ -1188,6 +1242,22 @@ llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() con | |
return (llm_graph_input_attn_no_cache *) res->add_input(std::move(inp)); | ||
} | ||
|
||
llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache_iswa() const { | ||
// Use the sliding window size from hyperparameters | ||
// If hparams.n_swa is 0, use a default value (128) | ||
const int n_swa = hparams.n_swa > 0 ? hparams.n_swa : 128; | ||
|
||
auto inp = std::make_unique<llm_graph_input_attn_no_cache>(hparams, cparams, n_swa); | ||
|
||
// note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch | ||
inp->kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); | ||
ggml_set_input(inp->kq_mask); | ||
|
||
inp->kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->kq_mask, GGML_TYPE_F16) : inp->kq_mask; | ||
|
||
return (llm_graph_input_attn_no_cache *) res->add_input(std::move(inp)); | ||
} | ||
|
||
ggml_tensor * llm_graph_context::build_attn( | ||
llm_graph_input_attn_no_cache * inp, | ||
ggml_cgraph * gf, | ||
|
@@ -1522,7 +1592,8 @@ void llm_graph_context::build_pooling( | |
ggml_tensor * cls, | ||
ggml_tensor * cls_b, | ||
ggml_tensor * cls_out, | ||
ggml_tensor * cls_out_b) const { | ||
ggml_tensor * cls_out_b, | ||
ggml_tensor * cls_norm) const { | ||
if (!cparams.embeddings) { | ||
return; | ||
} | ||
|
@@ -1570,6 +1641,11 @@ void llm_graph_context::build_pooling( | |
cur = ggml_add(ctx0, ggml_mul_mat(ctx0, cls, inp), cls_b); | ||
cur = ggml_tanh(ctx0, cur); | ||
|
||
if (cls_norm) { | ||
// normalization head | ||
cur = build_norm(cur, cls_norm, nullptr, LLM_NORM, 0); | ||
} | ||
|
||
// some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en | ||
// https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896 | ||
if (cls_out) { | ||
|
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.