I have been an avid numba user for a long time, since it usually gave me a good speedup with little change to my code base. But since its GPU support is limited to Nvidia, I am thinking of switching to JAX and making more use of the math-oriented tooling in the JAX ecosystem.
I mostly work with hyperspectral image data and need to process the data per pixel. I have already started a small JAX project, lumax, and even have some working code. Now I am struggling a bit with performance, and I am not sure whether I am doing things the way they are intended.
I have the main function `amsa_scalar` in `lumax.models.hapke.amsa`, and I derive two new functions, `amsa_vector` and `amsa_image`, using `jax.vmap`:
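A minimal sketch of what that derivation looks like; the argument layout of `amsa_scalar` (per-pixel parameter first, shared geometry second) is an assumption here and may not match the real signature in lumax:

```python
import jax
from lumax.models.hapke.amsa import amsa_scalar

# Assumed layout: amsa_scalar(w, geometry) -> reflectance spectrum for one pixel.
# Map over the pixel axis of `w` while broadcasting the shared geometry ...
amsa_vector = jax.vmap(amsa_scalar, in_axes=(0, None))
# ... and map once more over the row axis to handle a full image.
amsa_image = jax.vmap(amsa_vector, in_axes=(0, None))
```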
This works nicely for images, and I get the correct result.
My problem lies in inverting this model: given a measured reflectance, I need to determine the first parameter of the function.
I am currently using optimistix (see `lumax/inverse.py`) to do the inversion and obtain the parameter, but it is quite slow and memory hungry (see patrick-kidger/optimistix#70).
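Roughly, the per-pixel inversion looks like the following sketch (not the actual code in `lumax/inverse.py`; the residual signature and the `(measured, geometry)` arguments are assumptions):

```python
import jax
import optimistix as optx
from lumax.models.hapke.amsa import amsa_scalar

# Assumed per-pixel residual: model spectrum minus measured spectrum.
def residuals(w, args):
    measured, geometry = args
    return amsa_scalar(w, geometry) - measured

solver = optx.LevenbergMarquardt(rtol=1e-6, atol=1e-6)

def invert_pixel(measured, geometry, w0):
    # One Levenberg-Marquardt solve per pixel.
    sol = optx.least_squares(residuals, solver, w0, args=(measured, geometry))
    return sol.value

# Batch the solve over all pixels of a flattened image.
invert_image = jax.vmap(invert_pixel, in_axes=(0, None, 0))
```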
I tried doing this with numba and got much better results (even though I had to implement the LM algorithm myself in numba). The main idea was to jit all functions, iterate over all pixels in parallel, and apply the optimization per pixel. This gave the best performance; moving the `prange` statement further down added more overhead per iteration.
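The numba version followed roughly this layout (a sketch; `lm_solve_pixel` is a hypothetical stand-in for the hand-written LM routine):

```python
import numpy as np
from numba import njit, prange

@njit(parallel=True)
def invert_image_numba(measured, geometry, w0):
    # `measured` has shape (n_pixels, n_bands); one independent LM solve per pixel.
    n_pixels = measured.shape[0]
    out = np.empty(n_pixels)
    # prange at the outermost (pixel) loop gave the least overhead per iteration.
    for i in prange(n_pixels):
        # lm_solve_pixel: the hand-written, njit-compiled LM solver (hypothetical name).
        out[i] = lm_solve_pixel(measured[i], geometry, w0[i])
    return out
```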
Even though I have a working model, I think there is much more performance to be squeezed out here, so I wanted to see if someone can help me with this! I am still inexperienced with JAX, so any help is appreciated.
Things I also want to test are:
- Using a custom derivative with `jax.custom_jvp`, since I have implemented the analytical derivative of the AMSA model (see the sketch after this list).
- Exploring `shard_map` a bit more. I experimented with it locally; it has a nicer interface, but it is substantially slower and does not seem to be compatible with optimistix out of the box.
- Different LM algorithms implemented in JAX.
- Parallelizing the algorithm in a different way. JAX does not give explicit control over parallel execution the way numba does with `prange`; parallelism is implicit through vectorization with `jax.vmap`.
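For the `custom_jvp` idea, this is the kind of wiring I have in mind (a sketch; `amsa_scalar_analytic_grad` is a hypothetical name for the hand-derived derivative with respect to the first parameter):

```python
import jax
from lumax.models.hapke.amsa import amsa_scalar

@jax.custom_jvp
def amsa(w, geometry):
    return amsa_scalar(w, geometry)

@amsa.defjvp
def amsa_jvp(primals, tangents):
    w, geometry = primals
    w_dot, _ = tangents
    primal_out = amsa(w, geometry)
    # Analytical derivative of the AMSA model w.r.t. w (hypothetical helper),
    # combined with the input tangent via the chain rule.
    tangent_out = amsa_scalar_analytic_grad(w, geometry) * w_dot
    return primal_out, tangent_out
```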