re: usage and typing of array I/O with NumPy exports #28889
-
Hey. I'm trying to skirt around the (for me unnecessary) dependency on For now, I came up with the following, which loads the train/test contents into uint8 arrays via I'd like to assert that what comes out here are actually JAX arrays, which I'm not sure is the case, since some of JAX's array I/O APIs are re-exports from NumPy (see e.g. TL;DR: Is it correct to just go
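For the assertion itself, one option is an `isinstance` check against `jax.Array`, the common type of JAX arrays, applied to every leaf of the loaded splits. This is a minimal sketch: `is_jax_array` and `all_jax_arrays` are hypothetical helpers, and the zero-filled arrays are placeholders, not the data or loading code from the question.

```python
import jax
import jax.numpy as jnp
import numpy as np


def is_jax_array(x) -> bool:
    """Hypothetical helper: True for JAX arrays, False for NumPy ndarrays."""
    return isinstance(x, jax.Array)


def all_jax_arrays(tree) -> bool:
    """Hypothetical helper: True if every leaf of a pytree is a JAX array."""
    return all(is_jax_array(leaf) for leaf in jax.tree_util.tree_leaves(tree))


# Placeholder train/test splits: one JAX array, one plain NumPy array.
splits = {
    "train": jnp.zeros((3, 4), dtype=jnp.uint8),
    "test": np.zeros(4, dtype=np.uint8),
}

print(is_jax_array(splits["train"]), is_jax_array(splits["test"]))
print(all_jax_arrays(splits))
```

Checking the pytree leaves rather than a single array catches the case where only some of the splits came back as NumPy arrays.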
-
`jax.numpy.frombuffer` will always return JAX arrays. `jax.numpy.load` will return JAX arrays from `.npy` files, but not from `.npz` files. This is because the object returned by `np.load` when passed a `.npz` file is a non-trivial lazy view of the buffers on disk, and there was no easy way to hook into this in order to make it return JAX arrays. It's something that could probably be addressed with some effort if it were important, but nobody has told us before now that it's important.

Also, NumPy's array serialization formats (`.npy`, `.npz`) are not particularly well-suited for JAX, because they pre-suppose the particular set of dtypes that NumPy supports, and so there's no way to specify `bfloat16` or o…
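A minimal sketch of the behaviour described in the reply, using in-memory buffers instead of files on disk. The `train`/`test` arrays are hypothetical stand-ins for the dataset, and converting each `.npz` entry with `jnp.asarray` is one possible way to post-process the archive, not necessarily the approach the reply goes on to suggest.

```python
import io

import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical uint8 arrays standing in for the train/test contents.
train = np.arange(12, dtype=np.uint8).reshape(3, 4)
test = np.arange(4, dtype=np.uint8)

# jnp.frombuffer: raw bytes -> JAX array.
from_bytes = jnp.frombuffer(train.tobytes(), dtype=jnp.uint8)

# .npy: jnp.load converts the single stored array to a JAX array.
npy_buf = io.BytesIO()
np.save(npy_buf, train)
npy_buf.seek(0)
from_npy = jnp.load(npy_buf)

# .npz: jnp.load hands back NumPy's lazy archive object, whose entries
# are plain np.ndarray values rather than JAX arrays.
npz_buf = io.BytesIO()
np.savez(npz_buf, train=train, test=test)
npz_buf.seek(0)
archive = jnp.load(npz_buf)
raw_train = archive["train"]

# Post-process: convert every entry explicitly, then close the archive.
from_npz = {name: jnp.asarray(archive[name]) for name in archive.files}
archive.close()

for label, value in [("frombuffer", from_bytes), (".npy", from_npy),
                     (".npz entry", raw_train), (".npz + asarray", from_npz["train"])]:
    print(label, isinstance(value, jax.Array))
```

The conversion happens before `archive.close()`, because the lazy archive can no longer read its entries once it is closed.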