Asynchronous dispatch to specific devices / CPU cores #28996
markus7800 asked this question in Q&A
For CPU, one workaround I am considering is some sort of multiprocessing setup, where
However, this does not feel like an intended way of using JAX.
Hello,
I know that the primary target of JAX is single-program, multiple-data (SPMD) applications.
However, I would kindly like to ask for advice on whether it is possible to use JAX for multiple-program, multiple-data (MPMD) applications.
Let's consider the following benchmark:
If we run this, for instance, on a Google Colab v2-8 TPU instance, we get the following measurements:
So far, so expected.
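As a rough, hypothetical stand-in for this kind of benchmark (it is not the original code), the pattern being measured can be sketched as follows. Two jit-compiled functions with different output shapes (`f` returns a matrix, `g` a vector; both bodies are invented for illustration) are dispatched without blocking, either both to one device or one per device, and only then awaited with `jax.block_until_ready`. The helper `run` is hypothetical. The sketch uses two virtual CPU devices via `--xla_force_host_platform_device_count=2` so it runs anywhere; on a TPU instance you would take devices from `jax.devices()` instead.

```python
import os

# Must be set before JAX is imported: expose 2 virtual CPU devices.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=2"
)

import time

import jax
import jax.numpy as jnp


# Hypothetical stand-ins for f and g: note the different output shapes.
@jax.jit
def f(x):
    return jnp.tanh(x @ x)  # shape (n, n)


@jax.jit
def g(x):
    return jnp.sum(jnp.sin(x @ x), axis=0)  # shape (n,)


def run(devices, x):
    """Dispatch f and g without blocking, then wait for both results."""
    xf = jax.device_put(x, devices[0])   # f runs where its input is committed
    xg = jax.device_put(x, devices[-1])  # g runs on the last given device
    start = time.perf_counter()
    out_f = f(xf)  # may return before the computation finishes (async dispatch)
    out_g = g(xg)
    jax.block_until_ready((out_f, out_g))
    return out_f, out_g, time.perf_counter() - start


cpus = jax.devices("cpu")
x = jnp.ones((1024, 1024), dtype=jnp.float32)

# Warm-up calls so the timed runs exclude compilation for each device.
run(cpus[:1], x)
run(cpus[:2], x)

out_f, out_g, t1 = run(cpus[:1], x)    # roughly "async 1 device"
out_f2, out_g2, t2 = run(cpus[:2], x)  # roughly "async 2 devices"
print(f"1 device: {t1:.4f}s, 2 devices: {t2:.4f}s")
```

On the CPU backend the two timings may come out about the same, which is the behaviour described in the first question below; this sketch does not by itself establish why.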
First Question: If we run this on a multi-core CPU (with `--xla_force_host_platform_device_count` set appropriately), `async 2 devices` has the same runtime as `async 1 device`, and only one CPU core is utilised. Is there a way to dispatch multiple computations asynchronously to specific CPU cores to get the speed-up back? Note that we cannot use `pmap` at all, because the output shapes of `f` and `g` are different.

A similar issue arises if we want to split the computation across specific TPUs. Let's say we compute `f` with the first 4 TPUs and `g` with the remaining TPUs:

Second Question: Is there a way to dispatch the computations in `fg_pmap_async_devices` asynchronously and run them in parallel? Again, we cannot use `pmap`, because the output shapes of `f` and `g` are different.

General advice on the use case I have in mind is also very much appreciated.
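For the first question, one approach that may be worth trying (its speed-up is not verified here) is to issue each computation from its own Python thread, one thread per device. This rests on the assumption that JAX releases the GIL while an XLA computation executes, so a thread blocked on one device does not stop another thread from driving a second device. The helper `run_on` and the bodies of `f` and `g` are hypothetical. Also note that, to my understanding, the virtual devices created by `--xla_force_host_platform_device_count` are not tied to particular physical cores, so this targets devices rather than cores.

```python
import os

# Must be set before JAX is imported: expose 2 virtual CPU devices.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=2"
)

from concurrent.futures import ThreadPoolExecutor

import jax
import jax.numpy as jnp


# Hypothetical f and g with different output shapes.
@jax.jit
def f(x):
    return jnp.tanh(x @ x)  # shape (n, n)


@jax.jit
def g(x):
    return jnp.sum(jnp.sin(x @ x), axis=0)  # shape (n,)


def run_on(fn, x, device):
    """Run fn on one specific device and block until its result is ready."""
    x = jax.device_put(x, device)  # commit the input, so fn runs on `device`
    return fn(x).block_until_ready()


cpu0, cpu1 = jax.devices("cpu")[:2]
x = jnp.ones((512, 512), dtype=jnp.float32)

# One worker thread per device; each thread blocks only on its own result.
with ThreadPoolExecutor(max_workers=2) as pool:
    fut_f = pool.submit(run_on, f, x, cpu0)
    fut_g = pool.submit(run_on, g, x, cpu1)
    out_f, out_g = fut_f.result(), fut_g.result()
```

Unlike `pmap`, nothing here requires `f` and `g` to share an output shape, since each call is an independent single-device computation.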
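For the second question, a sketch of one alternative to `pmap`: build two disjoint sub-meshes (first half and second half of the devices), place each input on its own mesh with `NamedSharding`, and call separately jit-compiled `f` and `g`. Because a jit call runs on the devices its inputs are committed to, `f` and `g` are free to have different output shapes. The function bodies, the mesh axis name `"i"`, and the choice of sharding the first axis are assumptions for illustration. Both calls are dispatched before either result is awaited; whether the two programs then actually overlap on a v2-8 is not confirmed here. To run anywhere, the sketch uses 8 virtual CPU devices; on TPU you would use `jax.devices()` instead.

```python
import os

# Must be set before JAX is imported: expose 8 virtual CPU devices.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=8"
)

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


# Hypothetical f and g with different output shapes.
@jax.jit
def f(x):
    return jnp.tanh(x @ x.T)  # shape (n, n)


@jax.jit
def g(x):
    return jnp.sum(jnp.sin(x), axis=1)  # shape (n,)


devices = jax.devices("cpu")  # on a TPU v2-8: jax.devices()
half = len(devices) // 2
mesh_f = Mesh(np.array(devices[:half]), ("i",))  # first group of devices
mesh_g = Mesh(np.array(devices[half:]), ("i",))  # remaining devices

x = jnp.ones((1024, 256), dtype=jnp.float32)
# Shard the first axis of x across each group of devices.
x_f = jax.device_put(x, NamedSharding(mesh_f, P("i")))
x_g = jax.device_put(x, NamedSharding(mesh_g, P("i")))

# Dispatch both computations before waiting on either one.
out_f = f(x_f)
out_g = g(x_g)
jax.block_until_ready((out_f, out_g))
```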