Initializer for Penzai's Linear layer does not support jax.nn.initializers #107

Open
Samarendra109 opened this issue Feb 7, 2025 · 3 comments

Comments

@Samarendra109
I am trying to create a simple linear layer as follows:

from penzai import pz
import jax

embed_axis = "embed_axis"
head_axis = "head_axis"
num_heads = 4
embed_size = 10

layer = pz.nn.Linear.from_config(
            name="layer_name",
            init_base_rng=jax.random.key(42),
            input_axes={embed_axis: embed_size},
            output_axes={
                head_axis: num_heads, 
                f"{embed_axis}/{head_axis}": embed_size//num_heads
            },
            initializer=jax.nn.initializers.xavier_normal(),   
        )

I am getting the error:


TypeError                                 Traceback (most recent call last)
in <cell line: 0>()
      7 embed_size = 10
      8 
----> 9 layer = pz.nn.Linear.from_config(
     10             name="layer_name",
     11             init_base_rng=jax.random.key(42),

1 frames
/usr/local/lib/python3.11/dist-packages/penzai/nn/parameters.py in make_parameter(name, init_base_rng, initializer, metadata, *init_args, **init_kwargs)
    110     metadata = {}
    111   return variables.Parameter(
--> 112       value=initializer(
    113           derive_param_key(init_base_rng, name), *init_args, **init_kwargs
    114       ),

TypeError: variance_scaling.<locals>.init() got an unexpected keyword argument 'input_axes'

I don't see anything in the documentation that explains the cause of this error.

@danieldjohnson
Collaborator

Ah, this could be better documented! The initializer for Penzai's Linear layer doesn't directly use JAX-style initializers like jax.nn.initializers.xavier_normal(). You should be able to use pz.nn.xavier_normal_initializer instead.

@Samarendra109
Author

Would it be possible to add a feature that wraps JAX initializers into Penzai initializers? Currently, Penzai implements only two initializers for neural networks. I believe the implementation would be similar to Penzai's existing variance_scaling_initializer.

If this is a valid feature request, I'd be happy to contribute, provided the main contributors consider it feasible and useful.
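One possible shape for such a wrapper, as a pure-Python sketch: it adapts a JAX-style initializer with signature `(key, shape, dtype)` to a Penzai-style one that receives named axes as keyword arguments. The name `wrap_jax_initializer` and the exact keyword signature are illustrative assumptions, not part of penzai's API; the real adapter would need to match penzai's `LinearInitializer` protocol and produce a named array. A stand-in zeros initializer is used here so the adapter logic runs without JAX installed.

```python
def wrap_jax_initializer(jax_init):
    """Hypothetical adapter: (key, shape, dtype) initializer -> named-axis initializer."""
    def named_init(key, *, input_axes, output_axes, dtype=float, **unused_kwargs):
        # Flatten the named axes into a positional shape, recording the axis
        # order so the result could be re-tagged with names afterwards.
        merged = {**input_axes, **output_axes}
        names = list(merged)
        shape = tuple(merged[n] for n in names)
        return names, jax_init(key, shape, dtype)
    return named_init

# Stand-in for e.g. jax.nn.initializers.zeros, so the sketch runs without JAX:
# builds a nested list of zeros with the requested shape.
def fake_zeros(key, shape, dtype=float):
    if not shape:
        return dtype(0)
    return [fake_zeros(key, shape[1:], dtype) for _ in range(shape[0])]

init = wrap_jax_initializer(fake_zeros)
names, value = init(0, input_axes={"embed": 3}, output_axes={"heads": 2})
# names == ["embed", "heads"]; value is a 3x2 nested list of 0.0
```

The real implementation would likely mirror how penzai's variance_scaling_initializer computes fan-in and fan-out from `input_axes` and `output_axes` before delegating to the underlying sampling routine.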

@danieldjohnson
Collaborator

Sure, contributions are welcome!

@danieldjohnson danieldjohnson changed the title Issue in creating Linear Layer Initializer for Penzai's Linear layer does not support jax.nn.initializers Apr 20, 2025