Saving large jax.arrays via jnp.savez or np.savez is slow - How to speed it up? #28809

Answered by jakevdp
isaac-tes asked this question in Q&A

jnp.savez is just a thin wrapper around np.savez. In both cases the JAX array is first converted to a NumPy array via an implicit device_get, and NumPy then writes the file. If you're using JAX entirely on CPU, the array conversion should be virtually free (the memory buffer is shared between JAX and NumPy), so any bottleneck you're seeing is purely on the NumPy side.
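One way to confirm where the time goes is to make the implicit conversion explicit and time it separately from the NumPy write. The sketch below assumes a recent JAX where jax.device_get returns a NumPy array for a jax.Array; timed_savez is a hypothetical helper for illustration, not part of JAX or NumPy. It writes to an in-memory buffer so disk speed doesn't muddy the numbers.

```python
import io
import time

import jax
import jax.numpy as jnp
import numpy as np


def timed_savez(x: jax.Array, file) -> tuple[float, float]:
    """Hypothetical helper: save x to an .npz archive, timing each phase.

    Returns (seconds spent converting to NumPy, seconds spent in np.savez).
    """
    t0 = time.perf_counter()
    # Explicit device -> host transfer; jnp.savez performs this implicitly.
    host = np.asarray(jax.device_get(x))
    t1 = time.perf_counter()
    # Pure NumPy serialization; on CPU this is where the remaining time goes.
    np.savez(file, x=host)
    t2 = time.perf_counter()
    return t1 - t0, t2 - t1


x = jnp.ones((1024, 1024), dtype=jnp.float32)
x.block_until_ready()  # finish any pending computation before timing
buf = io.BytesIO()
convert_s, save_s = timed_savez(x, buf)
print(f"device_get: {convert_s:.4f}s  np.savez: {save_s:.4f}s")
```

If the first number is negligible and the second dominates, the cost is in NumPy's .npz writing, not in JAX.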

Is there anything I am missing about how to save larger jax.Arrays more efficiently?

Can you say more about what your ultimate goal is? If your hope is to generate .npz files containing your array content, I don't think there's any improvement to be had. But if, say, you're doing more general checkpointing within JAX workflows,…

Answer selected by isaac-tes