
How to jit trace jnp.histogramdd with different bin sizes on different dimensions? #28814

Answered by dfm
chjz1024 asked this question in Q&A

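There are a few options here. If you always want to use the same bins, you can bind them with functools.partial before jitting. The bin counts then stay plain Python integers and are never traced: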

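import jax
import jax.numpy as jnp
from functools import partial

# bins is bound inside the partial, so jit only traces randarray
jax.jit(partial(jnp.histogramdd, bins=[10, 20, 30]))(randarray)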
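For reference, here is a self-contained sketch of the partial approach. It assumes a random (1000, 3) array as a stand-in for randarray (which comes from the question and isn't defined here), and the name hist_10_20_30 is just for illustration:

```python
from functools import partial

import jax
import jax.numpy as jnp

# bins=[10, 20, 30] is fixed inside the partial, so it stays a concrete
# Python list; only `sample` becomes a traced argument of the jitted function.
hist_10_20_30 = jax.jit(partial(jnp.histogramdd, bins=[10, 20, 30]))

# Assumed stand-in for `randarray`: 1000 points in 3 dimensions, shape (N, D).
sample = jax.random.normal(jax.random.PRNGKey(0), (1000, 3))
hist, edges = hist_10_20_30(sample)

print(hist.shape)                # (10, 20, 30): one axis per dimension
print([e.shape for e in edges])  # [(11,), (21,), (31,)]: bins + 1 edges each
```

Since the bins are baked into the partial, calling hist_10_20_30 again on an array of the same shape and dtype should reuse the compiled function. A different set of bins means building a new partial and jitting it again.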

Another option is to mark bins as static with static_argnames. Static arguments must be hashable, so pass your bins as a tuple instead of a list:

jax.jit(jnp.histogramdd, static_argnames=['bins'])(randarray, bins=(10, 20, 30))
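If callers might still pass a list, one option is a small wrapper that converts it to a tuple before the static-argument lookup. This is a minimal sketch: histogramdd_jit is a hypothetical helper name, not part of JAX, and the uniform (500, 3) sample is only for illustration:

```python
import jax
import jax.numpy as jnp

# bins is static, so its value must be hashable: an int or a tuple of ints works,
# a list does not.
_histogramdd_static = jax.jit(jnp.histogramdd, static_argnames=["bins"])


def histogramdd_jit(sample, bins=10):
    """Hypothetical wrapper: accepts bins as an int, a list, or a tuple."""
    if isinstance(bins, list):
        bins = tuple(bins)  # lists are unhashable; tuples can be static arguments
    return _histogramdd_static(sample, bins=bins)


sample = jax.random.uniform(jax.random.PRNGKey(0), (500, 3))
hist_a, edges_a = histogramdd_jit(sample, bins=[10, 20, 30])
# A different bins value is a different static argument, so JAX traces and
# compiles a second version of the function.
hist_b, edges_b = histogramdd_jit(sample, bins=(5, 5, 5))

print(hist_a.shape, hist_b.shape)  # (10, 20, 30) (5, 5, 5)
```

The trade-off is the same as with partial: each distinct bins value gets its own compiled version, which is fine when you only use a handful of bin configurations.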

Hope this helps!

Answer selected by chjz1024