Feature request: Support for quantization #113


Open
JEM-Mosig opened this issue Apr 17, 2025 · 6 comments
Labels
feature-request New feature or request

Comments

@JEM-Mosig (Contributor)

It'd be great if penzai supported model quantization out of the box. I know this is a lot of work to implement, but right now the lack of quantization support is the main reason I wouldn't want to fine-tune models with penzai.

@JEM-Mosig (Contributor, Author) commented Apr 19, 2025

One could use AQT for this if penzai exposed the dot_general function somehow. But Linear, for example, is implemented in terms of jnp.tensordot, which uses lax.dot_general under the hood, and NamedEinsum is implemented in terms of jnp.einsum, which does have an (experimental) _dot_general keyword argument, but Penzai does not expose it.
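To make the gap concrete, here is a minimal sketch (numpy standing in for jnp; shapes and names are illustrative) showing that the contraction Linear performs via tensordot is the same one an einsum, and hence a configurable dot_general, could perform:

```python
import numpy as np

# Stand-in for what jnp.tensordot(x, w, axes=1) computes inside a
# Linear layer: contract the last axis of x with the first axis of w.
rng = np.random.default_rng(0)
x = rng.standard_normal((4, 8))   # batch of 4 inputs, feature dim 8
w = rng.standard_normal((8, 16))  # weight mapping 8 -> 16 features

out_tensordot = np.tensordot(x, w, axes=1)

# The identical contraction written as an einsum; a quantizing
# dot_general could slot in at this level if it were exposed.
out_einsum = np.einsum("bi,io->bo", x, w)

assert np.allclose(out_tensordot, out_einsum)
```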

@danieldjohnson danieldjohnson added the feature-request New feature or request label Apr 20, 2025
@danieldjohnson (Collaborator)

Agreed this would be a very useful feature!

I think it should be pretty easy to prototype something like this without needing to directly change Penzai's implementation, because Penzai is designed to make it easy to hot-swap out model components. One implementation strategy:

  • Define a new class AQTLinear with the same __call__ interface as pz.nn.Linear, but defined in terms of dot_general instead of jnp.tensordot
    • It could have a classmethod AQTLinear.from_linear(cls, orig: pz.nn.Linear, config: aqt_config.DotGeneral) -> AQTLinear that builds itself and adopts the parameters from the original pz.nn.Linear. (This is similar to how LowRankAdapter replaces a Linear, or how KVCachingAttention replaces an Attention.)
    • Perhaps AQTLinear.__call__ could be implemented by using jnp.einsum instead of jnp.tensordot
  • Similarly define a new class AQTNamedEinsum based on penzai's NamedEinsum
  • Use selectors to replace them, e.g.
    (
      pz.select(model)
      .at_instances_of(pz.nn.Linear)
      .apply(lambda lin: AQTLinear.from_linear(lin, aqt_config))
    )
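As a rough sketch of the strategy above, using plain Python classes with numpy as stand-ins (these are not Penzai's actual pz.nn.Linear or selector APIs, and fake_quant_dot is only a placeholder for an AQT-configured dot_general):

```python
import numpy as np

class Linear:
    """Stand-in for pz.nn.Linear: contracts the last axis of x with w."""
    def __init__(self, w):
        self.w = w

    def __call__(self, x):
        return np.tensordot(x, self.w, axes=1)

class AQTLinear:
    """Same __call__ interface, but routes through a dot_general-like hook."""
    def __init__(self, w, dot_general):
        self.w = w
        self.dot_general = dot_general

    @classmethod
    def from_linear(cls, orig, dot_general):
        # Adopt the parameters of the original layer, analogous to how
        # LowRankAdapter replaces a Linear in Penzai.
        return cls(orig.w, dot_general)

    def __call__(self, x):
        return self.dot_general(x, self.w)

def fake_quant_dot(x, w):
    # Placeholder for an AQT-configured dot_general; here it is just
    # the plain contraction written as an einsum.
    return np.einsum("...i,io->...o", x, w)

lin = Linear(np.eye(3))
qlin = AQTLinear.from_linear(lin, fake_quant_dot)
x = np.arange(6.0).reshape(2, 3)
assert np.allclose(lin(x), qlin(x))  # identical until real quantization is plugged in
```

In Penzai itself, the last two lines would instead be a `pz.select(...).at_instances_of(...).apply(...)` pass over the whole model.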

If this works, it might make sense to add the AQTLinear/AQTNamedEinsum classes (and some helper functions) into Penzai, perhaps under penzai.toolshed.aqt. Then people who want to use AQT quantization could enable it with just a few extra lines.

(I probably won't have much bandwidth to experiment with this myself, but contributions are welcome!)

@JEM-Mosig (Contributor, Author)

I had the same thought! Your naming is better, though :) I'll try to implement the AQT layers when I find the time. But I think doing so will involve a lot of code copying, which is not ideal. If Linear etc. directly exposed the dot_general, then AQTLinear could just build a Linear with a dot_general function derived from the AQT config. In any case, I'll implement it first in the code-copy manner and see if it works.

@demoncoder-crypto

Are you accepting contributions from the community? I would love to work on this issue.

@JEM-Mosig (Contributor, Author)

It looks like using AQT directly is trickier than I thought: AQT objects carry around state for calibration, and the AQT code generally seems unfinished and abandoned. I'll see if I can implement some simple post-training quantization myself, but I can't guarantee I'll find enough time to do so.
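For reference, here is a minimal symmetric int8 post-training quantization sketch in plain numpy, independent of AQT (quantize_int8 and dequantize are hypothetical helper names, and a real implementation would want per-channel scales):

```python
import numpy as np

def quantize_int8(w):
    # Symmetric per-tensor quantization: map the weight range
    # [-max|w|, max|w|] onto the int8 range [-127, 127].
    scale = np.max(np.abs(w)) / 127.0
    q = np.clip(np.round(w / scale), -127, 127).astype(np.int8)
    return q, scale

def dequantize(q, scale):
    return q.astype(np.float32) * scale

rng = np.random.default_rng(0)
w = rng.standard_normal((8, 16)).astype(np.float32)
q, scale = quantize_int8(w)
w_hat = dequantize(q, scale)

# Rounding error per entry is at most half a quantization step.
assert np.max(np.abs(w - w_hat)) <= scale / 2 + 1e-6
```

A quantized Linear would store `q` and `scale` instead of `w` and dequantize (or use an integer contraction) inside `__call__`.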

As inspiration, I think this section of the AQT Readme and maybe this outdated user guide for Flax might be helpful.

@demoncoder-crypto I'm also just a community contributor, so I'm sure contributions would be welcome. Let me know if you make progress on this!

@JEM-Mosig (Contributor, Author)

@demoncoder-crypto This looks like the successor to AQT and might be interesting to look into: https://github.com/google/qwix
