Fix: Add missing attributes to hugging face conversion functions #116

Open · wants to merge 6 commits into main

5 changes: 5 additions & 0 deletions penzai/models/transformer/variants/gpt_neox.py
@@ -412,6 +412,11 @@ def gpt_neox_from_huggingface_model(
"eos_token_id",
"_attn_implementation_autoset",
"head_dim",
"is_decoder",
"attention_probs_dropout_prob",
"hidden_dropout_prob",
"type_vocab_size",
"_name_or_path",
}
bad_attributes = {}
for k, v in hf_config_attributes.items():
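
For context on why these attribute names matter: each conversion helper compares the incoming HuggingFace config against a reference default config, and any attribute that is neither handled during conversion nor left at its library default aborts the conversion. The loop body is collapsed in the diff above, so the following is an assumed sketch of that pattern (the helper name check_config_attributes is hypothetical), not the exact Penzai code.

# Assumed sketch of the attribute check that the sets above feed into; the
# real Penzai implementation is collapsed in the diff.
def check_config_attributes(hf_config_attributes, reference_attributes,
                            handled_or_ignored_attributes):
  bad_attributes = {}
  for k, v in hf_config_attributes.items():
    if k in handled_or_ignored_attributes:
      continue  # Converted explicitly or known to be safe to ignore.
    if k in reference_attributes and reference_attributes[k] == v:
      continue  # Still at the library default, so nothing to convert.
    bad_attributes[k] = v  # Unrecognized, non-default setting.
  if bad_attributes:
    raise ValueError(
        "Conversion of this model is not supported; unrecognized config "
        f"attributes: {bad_attributes}"
    )
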
3 changes: 3 additions & 0 deletions penzai/models/transformer/variants/llama.py
@@ -66,6 +66,7 @@ def llama_from_huggingface_model(
reference_attributes = transformers.LlamaConfig().to_dict()
handled_or_ignored_attributes = {
# Handled during conversion:
"hidden_act",
"hidden_size",
"intermediate_size",
"num_attention_heads",
@@ -80,8 +81,10 @@ def llama_from_huggingface_model(
"architectures",
"bos_token_id",
"eos_token_id",
"pad_token_id",
"_attn_implementation_autoset",
"head_dim",
"_name_or_path",
}
bad_attributes = {}
for k, v in hf_config_attributes.items():
24 changes: 14 additions & 10 deletions penzai/models/transformer/variants/llamalike_common.py
@@ -111,7 +111,7 @@ class LlamalikeTransformerConfig:
mlp_hidden_dim: int
num_decoder_blocks: int
vocab_size: int
mlp_variant: Literal["geglu_approx", "swiglu"]
mlp_variant: Literal["gelu_exact", "geglu_approx", "swiglu", "silu", "relu"]
tie_embedder_and_logits: bool
rope_wavelength: float = 10_000
rms_norm_eps: float = 1e-6
@@ -147,14 +147,18 @@ def build_llamalike_feedforward(
Returns:
An instance of TransformerFeedForward containing the GELU MLP blocks.
"""
if config.mlp_variant == "geglu_approx":
# Approximate is already the default in JAX, but we specify it explicitly
# because defaults differ between JAX and PyTorch.
act_fn = functools.partial(jax.nn.gelu, approximate=True)
elif config.mlp_variant == "swiglu":
act_fn = jax.nn.silu
else:
raise ValueError(f"Unsupported MLP variant {config.mlp_variant}")
# Approximate GeLU is already the default in JAX, but we specify it explicitly
# because defaults differ between JAX and PyTorch.
# Aliases for gelu and silu are maintained for backwards compatibility.
act_fn = {
"gelu": jax.nn.gelu,
"geglu_approx": functools.partial(jax.nn.gelu, approximate=True),
"gelu_exact": functools.partial(jax.nn.gelu, approximate=False),
"gelu_approx": functools.partial(jax.nn.gelu, approximate=True),
"swiglu": jax.nn.silu,
"silu": jax.nn.silu,
"relu": jax.nn.relu,
}[config.mlp_variant]

return model_parts.TransformerFeedForward([
pz.nn.BranchAndMultiplyTogether(
@@ -595,7 +599,7 @@ def llamalike_from_huggingface_model(
mlp_hidden_dim=hf_config.intermediate_size,
num_decoder_blocks=hf_config.num_hidden_layers,
vocab_size=hf_config.vocab_size,
mlp_variant="swiglu",
mlp_variant=hf_config.hidden_act,
Review comment (Collaborator):

Looking at this again, I realized there is a subtlety that makes this a bit complicated, and we probably need to handle it in a different way.

Specifically, in the current llamalike_common implementation, the mlp_variant is named based on the overall MLP design, e.g. "geglu" or "swiglu", and not the activation used inside that MLP, which would be "gelu" or "silu". These are related but not the same: a geglu MLP takes the product of a gelu and a linear layer, whereas a gelu MLP traditionally refers to just a normal MLP with a gelu activation and no separate linear layer.

On the other hand, in the HuggingFace transformers ecosystem, the hidden_act appears to refer specifically to the activation function used. Thus, when the model uses a geglu MLP, the hidden_act would be "gelu", and it's just assumed from context that there will be a separate linear layer.

Eventually it might make sense to try and switch conventions, but for now it seems simplest to keep the convention the same and avoid doing a bunch of refactoring, since in practice I think most of today's llama-like models use either geglu/gelu or swiglu/silu and not any other alternatives.

Would you mind making the following changes?

  • In the type for the mlp_variant arg, make it Literal["geglu_exact", "geglu_approx", "swiglu"] (so "geglu_exact" instead of "gelu_exact", and no silu/relu),
  • In build_llamalike_feedforward, map the "geglu_exact" MLP variant to the functools.partial(jax.nn.gelu, approximate=False) activation function, and don't allow anything other than "geglu_exact", "geglu_approx", "swiglu",
  • In llamalike_from_huggingface_model, have a separate mapping from HuggingFace's hidden_act to the Penzai mlp_variant, which would be something like
    {"silu": "swiglu", "gelu": "geglu_exact", "gelu_new": "geglu_approx"}[hf_config.hidden_act]
    
    (Note also that "gelu" needs to be mapped to the "geglu_exact" codepath because "gelu" in HuggingFace refers to the non-approximate version, whereas jax.nn.gelu is the approximate version by default.)

Thanks, and sorry for the complexity here!
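
A minimal sketch of the requested changes, for illustration only: the mapping names below are hypothetical, and the surrounding Penzai code is collapsed in the diff above.

import functools
import jax

# Config: per the suggestion, restrict the variant names to MLP designs.
#   mlp_variant: Literal["geglu_exact", "geglu_approx", "swiglu"]

# build_llamalike_feedforward: map each MLP design to its gating activation.
_MLP_VARIANT_TO_ACT_FN = {
    # HuggingFace's "gelu" is the exact form, so geglu_exact turns off JAX's
    # default approximation.
    "geglu_exact": functools.partial(jax.nn.gelu, approximate=False),
    "geglu_approx": functools.partial(jax.nn.gelu, approximate=True),
    "swiglu": jax.nn.silu,
}

def act_fn_for_mlp_variant(mlp_variant: str):
  if mlp_variant not in _MLP_VARIANT_TO_ACT_FN:
    raise ValueError(f"Unsupported MLP variant {mlp_variant}")
  return _MLP_VARIANT_TO_ACT_FN[mlp_variant]

# llamalike_from_huggingface_model: translate HuggingFace's hidden_act (an
# activation name) into Penzai's mlp_variant (an MLP design name).
_HIDDEN_ACT_TO_MLP_VARIANT = {
    "silu": "swiglu",
    "gelu": "geglu_exact",      # HF "gelu" is the non-approximate version.
    "gelu_new": "geglu_approx",
}
# e.g. mlp_variant=_HIDDEN_ACT_TO_MLP_VARIANT[hf_config.hidden_act]
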

rope_wavelength=hf_config.rope_theta,
tie_embedder_and_logits=False,
attention_type=attention_type,
7 changes: 7 additions & 0 deletions penzai/models/transformer/variants/mistral.py
@@ -71,6 +71,7 @@ def mistral_from_huggingface_model(
reference_attributes = transformers.MistralConfig().to_dict()
handled_or_ignored_attributes = {
# Handled during conversion:
"hidden_act",
"hidden_size",
"intermediate_size",
"num_attention_heads",
@@ -86,6 +87,12 @@ def mistral_from_huggingface_model(
"architectures",
"_attn_implementation_autoset",
"head_dim",
"is_decoder",
"pad_token_id",
"attention_probs_dropout_prob",
"hidden_dropout_prob",
"type_vocab_size",
"_name_or_path",
}
bad_attributes = {}
for k, v in hf_config_attributes.items():
64 changes: 64 additions & 0 deletions tests/models/transformer_consistency_test.py
@@ -36,12 +36,31 @@ class TransformerConsistencyTest(parameterized.TestCase):
)
def test_llama_consistency(self, num_attention_heads, num_key_value_heads):
cfg = transformers.LlamaConfig(
name_or_path="hf-internal-testing/tiny-random-LlamaForCausalLM",
vocab_size=11,
hidden_size=64,
intermediate_size=256,
num_hidden_layers=3,
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
attention_bias=False,
attention_dropout=0.0,
bos_token_id=0,
eos_token_id=1,
hidden_act="silu",
initializer_range=0.02,
max_position_embeddings=2048,
mlp_bias=False,
model_type="llama",
pad_token_id=-1,
pretraining_tp=1,
rms_norm_eps=1e-06,
rope_scaling=None,
rope_theta=10000.0,
tie_word_embeddings=False,
torch_dtype="float32",
transformers_version="4.44.2",
use_cache=True,
Review comment (Collaborator) on lines +39 to +63:

nit: mind adding a comment that says where these config settings came from? e.g. # This config is based on pretrained model "..."

(Is it the config for "hf-internal-testing/tiny-random-LlamaForCausalLM"? I'm curious whether that's representative of the configs people actually use, I wonder if we could take the config args from e.g. meta-llama/Llama-3.1-8B instead but with a smaller number of layers / hidden size.)

Same comment also applies to the other modified tests below.
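
For illustration, the requested provenance note might look something like the sketch below (an assumption only; whether to instead downscale a config such as meta-llama/Llama-3.1-8B is left open in this thread).

import transformers

# This config is based on the pretrained model
# "hf-internal-testing/tiny-random-LlamaForCausalLM", shrunk for fast tests.
cfg = transformers.LlamaConfig(
    vocab_size=11,
    hidden_size=64,
    # ... remaining settings as in the diff above ...
)
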

)

torch.manual_seed(0)
@@ -76,12 +95,33 @@ def test_llama_consistency(self, num_attention_heads, num_key_value_heads):
)
def test_mistral_consistency(self, num_attention_heads, num_key_value_heads):
cfg = transformers.MistralConfig(
name_or_path="hf-internal-testing/tiny-random-MistralForCausalLM",
is_decoder=True,
vocab_size=11,
hidden_size=64,
intermediate_size=256,
num_hidden_layers=3,
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
attention_dropout=0.0,
attention_probs_dropout_prob=0.1,
bos_token_id=1,
eos_token_id=2,
head_dim=16,
hidden_act="silu",
hidden_dropout_prob=0.1,
initializer_range=0.02,
max_position_embeddings=512,
model_type="mistral",
pad_token_id=0,
rms_norm_eps=1e-06,
rope_theta=10000.0,
sliding_window=4096,
tie_word_embeddings=False,
torch_dtype="float32",
transformers_version="4.44.2",
type_vocab_size=16,
use_cache=True,
)

torch.manual_seed(0)
@@ -110,11 +150,35 @@ def test_mistral_consistency(self, num_attention_heads, num_key_value_heads):

def test_gpt_neox_consistency(self):
cfg = transformers.GPTNeoXConfig(
name_or_path="organization-name/model-name",
is_decoder=True,
vocab_size=11,
hidden_size=64,
intermediate_size=256,
num_hidden_layers=3,
num_attention_heads=4,
attention_probs_dropout_prob=0.1,
hidden_dropout_prob=0.1,
type_vocab_size=16,
hidden_act="gelu",
attention_bias=True,
attention_dropout=0.0,
bos_token_id=0,
classifier_dropout=0.1,
eos_token_id=0,
hidden_dropout=0.0,
initializer_range=0.02,
layer_norm_eps=1e-05,
max_position_embeddings=512,
model_type="gpt_neox",
rope_scaling=None,
rotary_emb_base=10000,
rotary_pct=0.25,
tie_word_embeddings=False,
torch_dtype="float32",
transformers_version="4.44.2",
use_cache=True,
use_parallel_residual=True,
)

torch.manual_seed(0)
35 changes: 17 additions & 18 deletions uv.lock

(Generated lockfile; diff not rendered.)