I wrote a simple PDE solver in JAX, for the backwards 1-d heat equation.
The solver "rolls" a slice of values defined on a discretization of 1-d space backward through time, step-by-step, from the final time to time zero.
I used `lax.scan` for this, and the scanned rollback function solves a tridiagonal linear system at each step (I'm using a so-called "theta"-scheme).
My heat equation is parametrized by a few scalar parameters, and after computing the solution at time zero, I take the gradient of the solution with respect to the parameters.
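For concreteness, here is a minimal sketch of the kind of setup described above. It is not the actual solver. It assumes a single diffusion parameter `kappa`, the backward equation u_t + kappa * u_xx = 0 on [0, L] with u = 0 at both ends, the terminal condition u(T, x) = sin(pi x / L), and a hand-written Thomas algorithm so that `jax.grad` can differentiate through the tridiagonal solve. The names `thomas_solve`, `solve_at_time_zero` and `amplitude` are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax import lax

jax.config.update("jax_enable_x64", True)  # float64 for the sketch

def thomas_solve(lower, diag, upper, rhs):
    """Solve a tridiagonal system with the Thomas algorithm, written with
    lax.scan so that jax.grad can differentiate through it.
    lower[0] and upper[-1] are ignored."""
    zero = jnp.zeros((), diag.dtype)

    def forward(carry, coeffs):
        # Forward sweep: eliminate the sub-diagonal row by row.
        c_prev, d_prev = carry
        a, b, c, d = coeffs
        denom = b - a * c_prev
        c_new = c / denom
        d_new = (d - a * d_prev) / denom
        return (c_new, d_new), (c_new, d_new)

    _, (c_star, d_star) = lax.scan(forward, (zero, zero), (lower, diag, upper, rhs))

    def backward(x_next, coeffs):
        # Back substitution: x_i = d*_i - c*_i * x_{i+1}.
        c, d = coeffs
        x = d - c * x_next
        return x, x

    # reverse=True walks from the last row to the first; outputs keep row order.
    _, x = lax.scan(backward, zero, (c_star, d_star), reverse=True)
    return x

def solve_at_time_zero(kappa, theta=0.5, n_steps=100, n_space=200, T=1.0, L=1.0):
    """Theta-scheme rollback of u_t + kappa * u_xx = 0 from t = T to t = 0.
    Returns the terminal slice u_T and the time-zero slice u_0."""
    dx = L / (n_space + 1)
    dt = T / n_steps
    x = jnp.linspace(0.0, L, n_space + 2)[1:-1]  # interior grid points
    u_T = jnp.sin(jnp.pi * x / L)                # assumed terminal condition
    r = kappa * dt / dx**2
    ones = jnp.ones(n_space)
    # Implicit operator (I - theta * r * D), with D = tridiag(1, -2, 1).
    lower = -theta * r * ones
    diag = (1.0 + 2.0 * theta * r) * ones
    upper = -theta * r * ones

    def step(u, _):
        # Explicit part (I + (1 - theta) * r * D) u, with zero boundary values.
        u_pad = jnp.pad(u, 1)
        rhs = u + (1.0 - theta) * r * (u_pad[:-2] - 2.0 * u + u_pad[2:])
        return thomas_solve(lower, diag, upper, rhs), None

    u_0, _ = lax.scan(step, u_T, None, length=n_steps)
    return u_T, u_0

def amplitude(kappa, **kwargs):
    """Scalar summary of the time-zero slice: its projection onto u_T."""
    u_T, u_0 = solve_at_time_zero(kappa, **kwargs)
    return jnp.dot(u_0, u_T) / jnp.dot(u_T, u_T)

print(amplitude(0.1), jax.grad(amplitude)(0.1))
```

With this terminal condition the exact time-zero amplitude is exp(-kappa * pi^2 * T), with derivative -pi^2 * T * exp(-kappa * pi^2 * T), so both the solution and its sensitivity can be checked against closed forms while `n_steps` and `n_space` are varied.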
I get accurate results (for both the computed solution and its sensitivity) with 100 time steps and 200 space points.
However, if I increase the number of time steps, the computed solution becomes more accurate, but the gradient blows up.
Strangely, if I increase the number of space points (but keep the number of time steps the same), the computed solution becomes more accurate, but the gradient blows up again.
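One way to make "the gradient blows up" measurable at each resolution is to compare `jax.grad` against a central finite difference of the same function. The sketch below assumes a toy stand-in for the rollback, a scalar implicit decay step u_n = u_{n+1} / (1 + k * dt) scanned `n_steps` times; `grad_vs_fd` and `make_toy_rollback` are hypothetical names, and in practice `f` would be the real solver reduced to a scalar.

```python
import jax
import jax.numpy as jnp
from jax import lax

jax.config.update("jax_enable_x64", True)  # finite differences need float64

def grad_vs_fd(f, p, eps=1e-6):
    """Return (autodiff gradient, central finite difference, relative gap)."""
    g_ad = jax.grad(f)(p)
    g_fd = (f(p + eps) - f(p - eps)) / (2.0 * eps)
    rel = jnp.abs(g_ad - g_fd) / jnp.maximum(jnp.abs(g_fd), 1e-30)
    return g_ad, g_fd, rel

def make_toy_rollback(n_steps, T=1.0):
    """Hypothetical stand-in for the PDE rollback: one implicit decay step
    per time step, rolled back from t = T to t = 0 with lax.scan."""
    dt = T / n_steps

    def f(k):
        def step(u, _):
            return u / (1.0 + k * dt), None
        u_0, _ = lax.scan(step, jnp.ones(()), None, length=n_steps)
        return u_0

    return f

for n in (100, 1_000, 10_000):
    g_ad, g_fd, rel = grad_vs_fd(make_toy_rollback(n), 0.5)
    print(n, float(g_ad), float(g_fd), float(rel))
```

If the relative gap stays small as `n_steps` (or the number of space points) grows while the autodiff gradient itself grows, the derivative of the discrete scheme is genuinely large; if the gap grows, the differentiation itself is the suspect.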
I know this is a vague description (and without any code to back it up), but does what I am trying to do raise any red flags just from a conceptual point of view?
I can imagine that increasing the number of time steps leads to instabilities in the gradient computation (the gradient computation being a very long composition of operations), but I am confused as to why increasing the number of space points has the same effect.
Are there some inner workings related to JAX that immediately come to mind when reading this description?