fori_loop and jacrev: it works! But I thought it wouldn't #14643
-
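Well, here is a function that uses `fori_loop` and can be differentiated with `jacrev`:

```python
import jax
import jax.numpy as jnp

@jax.jit
def g(x):
    def body(i, val):
        v1, v2, v3, a = val
        # Accumulate the sums of a, a**2 and a**3 over the loop.
        return v1 + a[i], v2 + a[i]**2, v3 + a[i]**3, a
    val_init = (0., 0., 0., x)
    N = x.shape[0]
    res = jax.lax.fori_loop(0, N, body, val_init)
    return jnp.array(list(res[:-1]))

jax.jacrev(g)(jnp.array([1., 2., 3.]))  # runs without error
```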
When I read the docs, it seemed clear that `jax.lax.scan` is used when the `fori_loop` lower and upper bounds are known a priori (e.g. static values). Am I wrong, or has `fori_loop` changed (I am running 0.3.25 on Colab)? Is there a simple use case demonstrating that `jacrev` fails with `fori_loop` but works with `scan`? Thanks
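A quick way to see which primitive `fori_loop` actually emits is to print the jaxpr with `jax.make_jaxpr`. In this sketch (assuming a recent JAX version; `body`, `static_bounds` and `dynamic_bounds` are hypothetical helpers), Python-int bounds should produce a `scan`, while an upper bound passed in as a traced argument should produce a `while`.

```python
import jax
import jax.numpy as jnp
from jax import lax


def body(i, carry):
    # Hypothetical loop body: add the loop index to a float carry.
    return carry + i


def static_bounds(x):
    # Both bounds are Python ints, so they are concrete at trace time.
    return lax.fori_loop(0, 3, body, x)


def dynamic_bounds(x, n):
    # n is an argument of the traced function, so it is abstract (a tracer).
    return lax.fori_loop(0, n, body, x)


static_jaxpr = str(jax.make_jaxpr(static_bounds)(jnp.float32(0.0)))
dynamic_jaxpr = str(jax.make_jaxpr(dynamic_bounds)(jnp.float32(0.0), 3))

print("scan" in static_jaxpr)    # expected: True
print("while" in dynamic_jaxpr)  # expected: True
```

Inside a jitted function, `x.shape[0]` is a plain Python int, so it behaves like the static case here.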
-
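Thanks for the question! `N` here is known at compile time, because array shapes are static, so `fori_loop` can lower to `scan`. If you had used a traced array value in place of `N` (for example, a non-static argument of the jitted function), the trip count would be dynamic, `fori_loop` would fall back to `while_loop`, and your function would not be reverse-mode differentiable.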
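To make this concrete, here is a minimal sketch that reuses the loop body from the question (assuming a recent JAX version; `moments`, `g_static` and `g_dynamic` are hypothetical names). When the trip count comes from `x.shape[0]`, `jax.jacrev` succeeds. When it is passed as a non-static argument to a jitted function, it becomes a tracer, `fori_loop` falls back to `while_loop`, and `jax.jacrev` raises an error while forward-mode `jax.jacfwd` still works. The exception type caught below (`ValueError`) is an assumption and may differ across JAX versions.

```python
import jax
import jax.numpy as jnp
from jax import lax


def moments(x, n):
    # Same accumulation as `g` in the question, but the trip count is a parameter.
    def body(i, val):
        v1, v2, v3, a = val
        return v1 + a[i], v2 + a[i] ** 2, v3 + a[i] ** 3, a

    zero = jnp.zeros((), dtype=x.dtype)  # avoid weak-type carry mismatches
    res = lax.fori_loop(0, n, body, (zero, zero, zero, x))
    return jnp.stack(res[:-1])


x = jnp.array([1.0, 2.0, 3.0])

# Static trip count: x.shape[0] is a Python int, so fori_loop uses scan.
g_static = jax.jit(lambda x: moments(x, x.shape[0]))
jac_static = jax.jacrev(g_static)(x)

# Dynamic trip count: n is a non-static jit argument, so it is traced
# and fori_loop falls back to while_loop.
g_dynamic = jax.jit(moments)

# Forward mode still works through while_loop...
jac_fwd = jax.jacfwd(g_dynamic)(x, 3)

# ...but reverse mode does not (assumed to raise ValueError).
reverse_failed = False
try:
    jax.jacrev(g_dynamic)(x, 3)
except ValueError as err:
    reverse_failed = True
    print(err)
```

Marking `n` as static, e.g. `jax.jit(moments, static_argnums=1)`, would make the bound concrete again and restore the `scan` path, at the cost of recompiling for each distinct `n`.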