How to handle static args in custom Jaxpr interpreters? #29072
huterguier asked this question in Q&A
-
There's no absolutely general way to detect which arguments should be static and which should be dynamic: it's why … `def transform(fun, static_argnums=()): ...` and then split your arguments into static and dynamic, closing over the static arguments and passing only the dynamic arguments to …
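A minimal sketch of that approach, assuming positional-only arguments and using `jax.core.eval_jaxpr` as a stand-in for the custom interpreter from the guide. The names `transform`, `f`, and `outer` are hypothetical, chosen only for illustration:

```python
import functools

import jax
import jax.numpy as jnp


def transform(fun, static_argnums=()):
    """Hypothetical jaxpr-interpreting transform with explicit static args."""
    static_argnums = frozenset(static_argnums)

    @functools.wraps(fun)
    def wrapped(*args):
        # Split positional args: static ones are closed over, dynamic ones traced.
        dyn_idx = [i for i in range(len(args)) if i not in static_argnums]
        dyn_args = [args[i] for i in dyn_idx]

        def fun_of_dynamic(*dyn):
            # Re-insert the traced values; static values come from `args` unchanged.
            full = list(args)
            for i, d in zip(dyn_idx, dyn):
                full[i] = d
            return fun(*full)

        # Only dynamic args are traced, so static args stay concrete Python values.
        closed, out_shape = jax.make_jaxpr(fun_of_dynamic, return_shape=True)(*dyn_args)

        # Stand-in for a custom interpreter: evaluate the jaxpr on the dynamic leaves.
        flat_in = jax.tree_util.tree_leaves(dyn_args)
        flat_out = jax.core.eval_jaxpr(closed.jaxpr, closed.consts, *flat_in)
        out_tree = jax.tree_util.tree_structure(out_shape)
        return jax.tree_util.tree_unflatten(out_tree, flat_out)

    return wrapped


def f(x, n):
    # `n` drives Python control flow, so it must stay static.
    for _ in range(n):
        x = x * 2.0
    return x


@functools.partial(jax.jit, static_argnums=1)
def outer(x, n):
    # The outer jit's static arg is forwarded explicitly to the inner transform.
    return transform(f, static_argnums=(1,))(x, n)


result = outer(jnp.ones(3), 3)
```

The design choice is that the caller, not the transform, declares which arguments are static, mirroring `jax.jit`. `jax.make_jaxpr` itself also accepts a `static_argnums` argument, which could replace the manual split.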
-
I've encountered another issue while following this guide on writing custom Jaxpr interpreters. Whenever I use the transformed function inside a jitted function that has static args, tracing fails because my transformation doesn't know which of the incoming arguments were marked as static by the outer jit.
How can I determine which arguments are static? Is it enough to check which elements in args are hashable?
Thanks!
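A small illustration of why hashability alone is not a reliable signal: a Python scalar is hashable whether or not the caller intends it to be static, while JAX arrays raise `TypeError` when hashed. The helper `is_hashable` is hypothetical, written only for this check:

```python
import jax.numpy as jnp


def is_hashable(x):
    """Hypothetical helper: True if hash(x) succeeds."""
    try:
        hash(x)
        return True
    except TypeError:
        return False


# A Python int passes the check even if it was meant to be traced,
# e.g. as the dynamic `x` in a call made outside of any jit.
int_ok = is_hashable(3)
# JAX arrays are unhashable, so they fail the check.
array_ok = is_hashable(jnp.ones(3))
```

Hashability tells you whether a value *can* be used as a static argument, not whether it *should* be, which is why an explicit `static_argnums` is the safer contract.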