[shard_map + jax.lax.scan + vjp] How to reduce the number of communications #29608
Unanswered
PhilipVinc asked this question in Q&A
Replies: 1 comment
-
I don't see why pvary would not be allowed? ...
# ——— shard_map + lax.scan version ———
def vjp_scan(W, xs_shard, v_shard):
    print(W.shape, xs_shard.shape, v_shard.shape)
    # reshape each device's slice into (num_batches, batch_size, ...)
    num_batches = xs_shard.shape[0] // batch_size
    xs_chunks = xs_shard.reshape((num_batches, batch_size, M))
    v_chunks = v_shard.reshape((num_batches, batch_size))
    W = jax.lax.pvary(W, 's')

    # scan body: compute local gradient chunk and add to accumulator
    def body(acc, inputs):
        xs_c, v_c = inputs
        _, vjp_fn = jax.vjp(lambda w: model(w, xs_c), W)
        gW = vjp_fn(v_c)[0]
        return acc + gW, None

    # run the scan over all chunks
    grad_shard, _ = lax.scan(body, jax.lax.pvary(jnp.zeros_like(W), 's'), (xs_chunks, v_chunks))

    # do one all-reduce (psum) across the "s" devices
    return jax.lax.psum(grad_shard, "s")

# wrap it up in a shard_map
vjp_sm = shard_map(
    vjp_scan,
    mesh=mesh,
    in_specs=(P(None), P("s"), P("s")),
    out_specs=P(None),
)
...
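For anyone who wants to try this suggestion, here is a self-contained sketch of the setup the snippet assumes (a toy linear model, illustrative sizes M, N and batch_size, and a one-axis mesh named "s"; none of this is from the original reply). Whether pvary is accepted here is exactly the open question of this thread, so treat it as a reproduction harness rather than a confirmed solution; it assumes a recent JAX release that provides jax.shard_map and jax.lax.pvary.

# Reproduction harness for the snippet above (illustrative names and sizes).
# Requires a recent JAX with jax.shard_map and jax.lax.pvary.
import jax
import jax.numpy as jnp
from jax import lax, shard_map
from jax.sharding import PartitionSpec as P

M, N, batch_size = 16, 1024, 32
mesh = jax.make_mesh((jax.device_count(),), ("s",))

def model(w, x):
    # Toy stand-in for the real model: f(W: f32[M], x: f32[..., M]) -> f32[...]
    return x @ w

def vjp_scan(W, xs_shard, v_shard):
    # Reshape each device's slice into (num_batches, batch_size, M).
    num_batches = xs_shard.shape[0] // batch_size
    xs_chunks = xs_shard.reshape((num_batches, batch_size, M))
    v_chunks = v_shard.reshape((num_batches, batch_size))
    W = lax.pvary(W, "s")  # mark the replicated W as device-varying

    def body(acc, inputs):
        xs_c, v_c = inputs
        _, vjp_fn = jax.vjp(lambda w: model(w, xs_c), W)
        return acc + vjp_fn(v_c)[0], None

    # Accumulate per-device partial gradients, then reduce once at the end.
    grad_shard, _ = lax.scan(body, lax.pvary(jnp.zeros_like(W), "s"), (xs_chunks, v_chunks))
    return lax.psum(grad_shard, "s")

vjp_sm = shard_map(vjp_scan, mesh=mesh, in_specs=(P(None), P("s"), P("s")), out_specs=P(None))

W = jnp.ones((M,))
xs = jnp.linspace(0.0, 1.0, N * M).reshape(N, M)
v = jnp.ones((N,))

# If this compiles, the result should match the plain full-batch vjp.
grad_sharded = jax.jit(vjp_sm)(W, xs, v)
grad_full = jax.vjp(lambda w: model(w, xs), W)[1](v)[0]
print(jnp.allclose(grad_sharded, grad_full, atol=1e-4))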
-
I have a case where I have a function f(W: f32[M], x: f32[N,M]) -> f32[N] and I want to compute the vjp, vjp(f, W, xs)(v: f32[N]) -> f32[M]. To lower the memory cost, I can perform the vjp in K batches of size N/K instead of N and reduce the output. The natural way to implement this is using jax.lax.scan.
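For illustration (a minimal sketch, not part of the original question), here is the unsharded version of that idea, assuming a toy linear model and that K divides N: the vjp is evaluated chunk by chunk inside lax.scan and the W-cotangents are summed in the carry.

# Minimal sketch of the unsharded batched vjp (illustrative model and sizes).
import jax
import jax.numpy as jnp
from jax import lax

def model(W, x):
    # Toy stand-in: f(W: f32[M], x: f32[..., M]) -> f32[...]
    return x @ W

def vjp_batched(W, xs, v, K):
    N, M = xs.shape
    xs_chunks = xs.reshape(K, N // K, M)  # assumes K divides N
    v_chunks = v.reshape(K, N // K)

    def body(acc, chunk):
        x_c, v_c = chunk
        # vjp of one chunk only: peak memory scales with N/K instead of N
        _, vjp_fn = jax.vjp(lambda w: model(w, x_c), W)
        return acc + vjp_fn(v_c)[0], None

    gW, _ = lax.scan(body, jnp.zeros_like(W), (xs_chunks, v_chunks))
    return gW

W = jnp.ones((8,))
xs = jnp.linspace(0.0, 1.0, 16 * 8).reshape(16, 8)
v = jnp.ones((16,))
# Matches the full-batch vjp.
print(jnp.allclose(vjp_batched(W, xs, v, K=4),
                   jax.vjp(lambda w: model(w, xs), W)[1](v)[0]))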
I now want to combine this with sharding the N dimension, so the notation would be f(W: f32[M], x: f32[N@s,M]) -> f32[N@s]. However, jax.lax.scan does not support sharding across the scanned (first) dimension, so I resort to shard_map. My issue is the following: the vjp of the replicated parameter W: f32[M] does an all-reduce (psum), correctly. In the MWE below there is therefore an all-reduce (psum) for every iteration of the scan.
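To make the problem concrete, here is a guess at the shape of such a version (a reconstruction, not the OP's actual MWE), reusing model, mesh, M and batch_size from the harness under the reply above: the same scan, but with W left replicated, so the vjp with respect to W introduces a collective in every scan iteration.

# Reconstruction of the pattern described in the question (not the original
# MWE); reuses model, mesh, M and batch_size from the harness above.
def vjp_scan_replicated(W, xs_shard, v_shard):
    num_batches = xs_shard.shape[0] // batch_size
    xs_chunks = xs_shard.reshape((num_batches, batch_size, M))
    v_chunks = v_shard.reshape((num_batches, batch_size))

    def body(acc, inputs):
        xs_c, v_c = inputs
        # W stays replicated here, so the vjp w.r.t. W performs an all-reduce
        # (the psum_invariant the OP mentions below) in every scan iteration.
        _, vjp_fn = jax.vjp(lambda w: model(w, xs_c), W)
        return acc + vjp_fn(v_c)[0], None

    grad, _ = lax.scan(body, jnp.zeros_like(W), (xs_chunks, v_chunks))
    return grad  # already reduced inside the scan, so no final psum

vjp_sm_replicated = shard_map(
    vjp_scan_replicated, mesh=mesh, in_specs=(P(None), P("s"), P("s")), out_specs=P(None)
)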
But this is doing more work than necessary: I don't need to all-reduce at every iteration; I could instead accumulate the vjp separately on every device and then all-reduce the per-device partial sums once. However, I don't know how to code this in JAX. I feel I should declare W as jax.lax.pvary, but this is not correct. Does anybody have some insight?
To maybe add some more context: if I look at the jaxpr of the vjp_sm function above, I see that there is a psum_invariant inside the scan loop. How can I change my code to remove it (and then add one outside the scan loop)?