I am also stuck with the same problem of poor performance. I was gutsy and dug into `jax/jax/_src/interpreters/pxla.py` (line 160, at commit eef1f6c), but I don't think I can follow it any further from there, as I am not at all proficient in XLA. What is strange to me is that the data is internally passed around as a NumPy ndarray. In my case, I wrote a pre-processing pipeline and ran into the same slowdown, so there does seem to be something wrong here.

(This post was translated from Japanese to English with DeepL, free version.)
The context for this question revolves around transferring a single unpinned host array of shape and type `f32[4096, 8192]` to multiple devices (`gpu0..7`). The difficulty is that the runtime required for this transfer seems higher than anticipated. The broader goal is to perform an optimizer step across the sharded data on device.
Going off this tutorial, the approach is to shard the data and then pass it to the jitted function. However, the transfer itself takes a surprisingly long time.
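Concretely, the approach looks something like this (a sketch; the 1D mesh layout and the placeholder `update_step` are my own assumptions, not the tutorial's exact code):

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Host-side data: one unpinned NumPy array, f32[4096, 8192].
x_host = np.zeros((4096, 8192), dtype=np.float32)

# 1D mesh over all 8 devices; shard the leading axis across it.
mesh = Mesh(np.array(jax.devices()), ("data",))
sharding = NamedSharding(mesh, P("data", None))

# Each device ends up holding one f32[512, 8192] shard.
x_sharded = jax.device_put(x_host, sharding)

@jax.jit
def update_step(x):
    # Stand-in for the real optimizer step over the sharded data.
    return x * 0.99

y = update_step(x_sharded)
```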
I tried to get an idea of a reasonable level of performance to expect. Unfortunately, on weekends I don't have access to the H100/H200s I usually work on, so this was conducted on Colab with a TPU v2-8, but the story told by the metrics remains the same (just an order of magnitude off).
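The single-device baseline is along these lines (a sketch; the array contents and the use of IPython's `%timeit` are assumptions):

```python
import numpy as np
import jax

x_host = np.zeros((4096, 8192), dtype=np.float32)  # ~128 MiB of f32 on host
dev0 = jax.devices()[0]

# Time host -> one device; block so the async transfer is actually measured.
%timeit jax.device_put(x_host, dev0).block_until_ready()
```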
Output:
3.05 ms ± 46.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
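The sharded transfer is the same call with a sharding instead of a single device (again a sketch, with the same mesh/sharding setup as above):

```python
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

x_host = np.zeros((4096, 8192), dtype=np.float32)
mesh = Mesh(np.array(jax.devices()), ("data",))
sharding = NamedSharding(mesh, P("data", None))

# Same host array, but split across all 8 devices by device_put itself.
%timeit jax.device_put(x_host, sharding).block_until_ready()
```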
Output:
91.5 ms ± 2.04 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
What does not make sense to me here is why the sharded transfer takes 30 times as long as the single-device transfer. Naively, I would imagine the "poor man's" version could not possibly be better:
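By the "poor man's" version I mean slicing on the host and issuing one plain single-device transfer per device (a sketch; the helper name `poor_mans_transfer` is mine):

```python
import numpy as np
import jax

x_host = np.zeros((4096, 8192), dtype=np.float32)
devices = jax.devices()

def poor_mans_transfer(x):
    # One plain single-device transfer per device, one row-slice each.
    return [
        jax.device_put(shard, d)
        for shard, d in zip(np.split(x, len(devices), axis=0), devices)
    ]

%timeit [p.block_until_ready() for p in poor_mans_transfer(x_host)]
```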
Output:
26.9 ms ± 659 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
Now this is more in line with what I would expect! But I would also imagine that JAX should have some sort of advantage when using `device_put` with a sharding argument (e.g., it could pipeline/overlap the transfers to different devices). One idea I have is to attempt the following recipe (spelled out in the sketch below):

1. `gpu_buffers = [jnp.empty((512, 8192)) for _ in range(8)]`
2. `jax.make_array_from_single_device_arrays((4096, 8192), sharding, gpu_buffers)`
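Spelled out, that recipe would be roughly (a sketch; `jnp.empty` plus an explicit per-device `device_put` stands in for the non-existent `jax.empty`, and the mesh/sharding definitions are my assumptions):

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), ("data",))
sharding = NamedSharding(mesh, P("data", None))

# Step 1: one f32[512, 8192] buffer per device. Note that JAX has no truly
# uninitialised arrays: jnp.empty is implemented as jnp.zeros.
gpu_buffers = [
    jax.device_put(jnp.empty((512, 8192), dtype=jnp.float32), d)
    for d in jax.devices()
]

# Step 2: stitch the per-device buffers into one logical f32[4096, 8192]
# array without further copies. For real data, the per-device shards
# (e.g. from the manual transfers above) would be passed here instead.
x_global = jax.make_array_from_single_device_arrays(
    (4096, 8192), sharding, gpu_buffers
)
```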
To be clear, my questions are now:

1. Why is `device_put` so much slower than expected?
2. What about `params` (the model is really small)? For now, it just lives on `cuda:0`.