I'm trying to jit trace a function that uses `jnp.histogramdd` with a different number of bins per dimension. Passing `bins` as a single integer works, but passing it as a list fails whether or not I mark it static:

```python
import jax
import jax.numpy as jnp

randarray = jax.random.normal(jax.random.key(0), (10000, 3))

# Works:
# jax.jit(jnp.histogramdd, static_argnames=['bins'])(randarray, bins=20)
# ValueError: Non-hashable static arguments are not supported. An error occurred while trying to hash an object of type <class 'list'>, [10, 20, 30]. The error was: TypeError: unhashable type: 'list'
# jax.jit(jnp.histogramdd, static_argnames=['bins'])(randarray, bins=[10, 20, 30])
# ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[] bins argument of histogram_bin_edges
# jax.jit(jnp.histogramdd)(randarray, bins=[10, 20, 30])
```
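The two failures come from how `jax.jit` treats arguments: a static argument must be hashable (a list is not), while a non-static list is treated as a pytree whose integer leaves are replaced by tracers, so `histogram_bin_edges` receives a traced `int32[]` rather than a concrete bin count. A minimal sketch of the second point, using a hypothetical helper `record` that only inspects what it receives at trace time:

```python
import jax
import jax.numpy as jnp

observed = {}


def record(x, bins):
    # Hypothetical helper: runs at trace time and stores whether every entry
    # of `bins` is still a plain Python int (a concrete value).
    observed["all_int"] = all(isinstance(b, int) for b in bins)
    return x.sum()


x = jnp.ones((4, 3))

# `bins` as an ordinary argument: each list element is replaced by a tracer.
jax.jit(record)(x, [10, 20, 30])
traced_all_int = observed["all_int"]

# `bins` marked static and passed as a hashable tuple: elements stay Python ints.
jax.jit(record, static_argnames=["bins"])(x, bins=(10, 20, 30))
static_all_int = observed["all_int"]

print(traced_all_int, static_all_int)  # False True
```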
Answered by dfm (May 18, 2025)
There are a few options here. If you always want to use the same bins, something like:

```python
from functools import partial

jax.jit(partial(jnp.histogramdd, bins=[10, 20, 30]))(randarray)
```

Another option would be to use a tuple instead of a list for your bins:

```python
jax.jit(jnp.histogramdd, static_argnames=['bins'])(randarray, bins=(10, 20, 30))
```

Hope this helps!
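Building on the tuple option, a small wrapper can accept a list, as in the original call, and normalize it to a tuple before calling the jitted function. This is a sketch, not a JAX API: `histogramdd_jit` and `_histogramdd_static` are hypothetical names, and it assumes `bins` holds integer bin counts rather than explicit edge arrays.

```python
import jax
import jax.numpy as jnp

# Hypothetical module-level jitted function with `bins` as a compile-time constant.
_histogramdd_static = jax.jit(jnp.histogramdd, static_argnames=["bins"])


def histogramdd_jit(sample, bins):
    """Hypothetical wrapper: accepts an int or a list/tuple of ints for `bins`
    and converts sequences to a tuple so the static argument is hashable."""
    if isinstance(bins, (list, tuple)):
        bins = tuple(int(b) for b in bins)
    return _histogramdd_static(sample, bins=bins)


randarray = jax.random.normal(jax.random.key(0), (10000, 3))
hist, edges = histogramdd_jit(randarray, [10, 20, 30])
print(hist.shape)                # (10, 20, 30)
print([e.shape for e in edges])  # [(11,), (21,), (31,)]
```

Because `bins` is static, its value is part of the compilation cache key: calling with a new bins tuple triggers a new trace and compile, while repeated calls with the same bins reuse the compiled function.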
Answer selected by chjz1024.