As far as I understand, both `np.savez` and `jnp.savez` convert a `jax.Array` into a NumPy array via `np.asarray()` under the hood. Is there anything I am missing about how to save larger `jax.Array`s more efficiently? I am using JAX entirely on CPUs, if that is of help.
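A minimal reproduction of the two save paths described in the question might look like the sketch below. It assumes a CPU-only JAX install whose version still exports `jnp.savez`. The array shape is arbitrary, and it writes to in-memory `io.BytesIO` buffers instead of files so that nothing touches disk.

```python
import io

import jax.numpy as jnp
import numpy as np

# A moderately large float32 array on the default (CPU) device.
x = jnp.arange(1_000_000, dtype=jnp.float32).reshape(1000, 1000)

# Path 1: convert to NumPy explicitly, then call NumPy's savez.
buf_np = io.BytesIO()
np.savez(buf_np, x=np.asarray(x))

# Path 2: call jnp.savez directly; it hands the array to np.savez.
buf_jnp = io.BytesIO()
jnp.savez(buf_jnp, x=x)

# Both buffers now hold an ordinary .npz archive; rewind and load them back.
buf_np.seek(0)
buf_jnp.seek(0)
loaded_np = np.load(buf_np)["x"]
loaded_jnp = np.load(buf_jnp)["x"]
print(loaded_np.dtype, loaded_np.shape, np.array_equal(loaded_np, loaded_jnp))
```

Both paths yield identical `.npz` contents, which is consistent with `jnp.savez` adding nothing beyond the conversion.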
-
`jnp.savez` is just a thin wrapper around `np.savez`, and both work by converting a JAX array to NumPy via an implicit `device_get` and then calling `np.savez`. If you're using JAX entirely on CPU, then the array conversion should be virtually free (the memory buffer will be shared between JAX and NumPy), so any bottleneck you're seeing is purely on the NumPy side of things.

Can you say more about what your ultimate goal is? If your hope is to generate `.npz` files containing your array content, I don't think there's any improvement to be had. But if, say, you're doing more general checkpointing within JAX workflows,…
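One way to check where the time goes is to measure the JAX-to-NumPy conversion separately from the `np.savez` call. This is a sketch under a few assumptions: the setup is CPU-only, `time_save` is a hypothetical helper, and the array shape is arbitrary. Writing to an in-memory `io.BytesIO` buffer also excludes disk throughput, so timings against a real file path will differ.

```python
import io
import time

import jax
import jax.numpy as jnp
import numpy as np


def time_save(x):
    """Hypothetical helper: time the JAX->NumPy conversion and np.savez separately."""
    x.block_until_ready()  # ensure the computation producing x has finished
    t0 = time.perf_counter()
    host = jax.device_get(x)  # JAX -> NumPy; expected to be near-free on CPU
    t1 = time.perf_counter()
    buf = io.BytesIO()
    np.savez(buf, x=host)  # .npz serialization, done entirely by NumPy
    t2 = time.perf_counter()
    return t1 - t0, t2 - t1, buf


x = jnp.ones((2000, 2000), dtype=jnp.float32)
convert_s, write_s, buf = time_save(x)
print(f"conversion: {convert_s:.4f} s, np.savez: {write_s:.4f} s")
```

If the second number dominates, the cost lies in NumPy's serialization (plus disk I/O when writing to a real path), not in JAX.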