-
Notifications
You must be signed in to change notification settings - Fork 62
Fix: Add missing attributes to hugging face conversion functions #116
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -36,12 +36,31 @@ class TransformerConsistencyTest(parameterized.TestCase): | |
) | ||
def test_llama_consistency(self, num_attention_heads, num_key_value_heads): | ||
cfg = transformers.LlamaConfig( | ||
name_or_path="hf-internal-testing/tiny-random-LlamaForCausalLM", | ||
vocab_size=11, | ||
hidden_size=64, | ||
intermediate_size=256, | ||
num_hidden_layers=3, | ||
num_attention_heads=num_attention_heads, | ||
num_key_value_heads=num_key_value_heads, | ||
attention_bias=False, | ||
attention_dropout=0.0, | ||
bos_token_id=0, | ||
eos_token_id=1, | ||
hidden_act="silu", | ||
initializer_range=0.02, | ||
max_position_embeddings=2048, | ||
mlp_bias=False, | ||
model_type="llama", | ||
pad_token_id=-1, | ||
pretraining_tp=1, | ||
rms_norm_eps=1e-06, | ||
rope_scaling=None, | ||
rope_theta=10000.0, | ||
tie_word_embeddings=False, | ||
torch_dtype="float32", | ||
transformers_version="4.44.2", | ||
use_cache=True, | ||
Comment on lines
+39
to
+63
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: mind adding a comment that says where these config settings came from? e.g. (Is it the config for "hf-internal-testing/tiny-random-LlamaForCausalLM"? I'm curious whether that's representative of the configs people actually use, I wonder if we could take the config args from e.g. Same comment also applies to the other modified tests below. |
||
) | ||
|
||
torch.manual_seed(0) | ||
|
@@ -76,12 +95,33 @@ def test_llama_consistency(self, num_attention_heads, num_key_value_heads): | |
) | ||
def test_mistral_consistency(self, num_attention_heads, num_key_value_heads): | ||
cfg = transformers.MistralConfig( | ||
name_or_path="hf-internal-testing/tiny-random-MistralForCausalLM", | ||
is_decoder=True, | ||
vocab_size=11, | ||
hidden_size=64, | ||
intermediate_size=256, | ||
num_hidden_layers=3, | ||
num_attention_heads=num_attention_heads, | ||
num_key_value_heads=num_key_value_heads, | ||
attention_dropout=0.0, | ||
attention_probs_dropout_prob=0.1, | ||
bos_token_id=1, | ||
eos_token_id=2, | ||
head_dim=16, | ||
hidden_act="silu", | ||
hidden_dropout_prob=0.1, | ||
initializer_range=0.02, | ||
max_position_embeddings=512, | ||
model_type="mistral", | ||
pad_token_id=0, | ||
rms_norm_eps=1e-06, | ||
rope_theta=10000.0, | ||
sliding_window=4096, | ||
tie_word_embeddings=False, | ||
torch_dtype="float32", | ||
transformers_version="4.44.2", | ||
type_vocab_size=16, | ||
use_cache=True, | ||
) | ||
|
||
torch.manual_seed(0) | ||
|
@@ -110,11 +150,35 @@ def test_mistral_consistency(self, num_attention_heads, num_key_value_heads): | |
|
||
def test_gpt_neox_consistency(self): | ||
cfg = transformers.GPTNeoXConfig( | ||
name_or_path="organization-name/model-name", | ||
is_decoder=True, | ||
vocab_size=11, | ||
hidden_size=64, | ||
intermediate_size=256, | ||
num_hidden_layers=3, | ||
num_attention_heads=4, | ||
attention_probs_dropout_prob=0.1, | ||
hidden_dropout_prob=0.1, | ||
type_vocab_size=16, | ||
hidden_act="gelu", | ||
attention_bias=True, | ||
attention_dropout=0.0, | ||
bos_token_id=0, | ||
classifier_dropout=0.1, | ||
eos_token_id=0, | ||
hidden_dropout=0.0, | ||
initializer_range=0.02, | ||
layer_norm_eps=1e-05, | ||
max_position_embeddings=512, | ||
model_type="gpt_neox", | ||
rope_scaling=None, | ||
rotary_emb_base=10000, | ||
rotary_pct=0.25, | ||
tie_word_embeddings=False, | ||
torch_dtype="float32", | ||
transformers_version="4.44.2", | ||
use_cache=True, | ||
use_parallel_residual=True, | ||
) | ||
|
||
torch.manual_seed(0) | ||
|
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looking at this again, I realized there is a subtlety that makes this a bit complicated, and we probably need to handle it in a different way.
Specifically, in the current
llamalike_common
implementation, themlp_variant
is named based on the overall MLP design, e.g. "geglu" or "swiglu", and not the activation used inside that MLP, which would be "gelu" or "silu". These are related but not the same: a geglu MLP takes the product of a gelu and a linear layer, whereas a gelu MLP traditionally refers to just a normal MLP with a gelu activation and no separate linear layer.On the other hand, in the HuggingFace transformers ecosystem, the
hidden_act
appears to refer specifically to the activation function used. Thus, when the model uses a geglu MLP, thehidden_act
would be "gelu", and it's just assumed from context that there will be a separate linear layer.Eventually it might make sense to try and switch conventions, but for now it seems simplest to keep the convention the same and avoid doing a bunch of refactoring, since in practice I think most of today's llama-like models use either geglu/gelu or swiglu/silu and not any other alternatives.
Would you mind making the following changes?
mlp_variant
arg, make itLiteral["geglu_exact", "geglu_approx", "swiglu"]
(so "geglu_exact" instead of "gelu_exact", and no silu/relu,build_llamalike_feedforward
, map the "geglu_exact" MLP variant to thefunctools.partial(jax.nn.gelu, approximate=False)
activation function, and don't allow anything other than "geglu_exact", "geglu_approx", "swiglu",llamalike_from_huggingface_model
, have a separate mapping from HuggingFace'shidden_act
to the Penzaimlp_variant
, which would be something likejax.nn.gelu
is the approximate version by default.)Thanks, and sorry for the complexity here!