[mutable-arrays] upgrade scan to work with partial_eval_jaxpr_fwd #29353

Merged
merged 1 commit on Jun 10, 2025
14 changes: 11 additions & 3 deletions jax/_src/interpreters/partial_eval.py
@@ -1002,14 +1002,16 @@ def partial_eval_jaxpr_nounits(
def partial_eval_jaxpr_nounits_fwd(
jaxpr: ClosedJaxpr, unknowns: Sequence[bool],
instantiate: bool | Sequence[bool],
fwd: bool | Sequence[bool] = True,
) -> tuple[ClosedJaxpr, ClosedJaxpr, list[bool], list[AbstractValue], list[int | None]]:
instantiate = tuple(instantiate) if isinstance(instantiate, list) else instantiate
return _partial_eval_jaxpr_nounits(jaxpr, tuple(unknowns), instantiate, True)
fwd = tuple(fwd) if isinstance(fwd, list) else fwd
return _partial_eval_jaxpr_nounits(jaxpr, tuple(unknowns), instantiate, fwd)

@weakref_lru_cache
def _partial_eval_jaxpr_nounits(
jaxpr: ClosedJaxpr, in_unknowns: Sequence[bool],
instantiate: bool | Sequence[bool], fwd: bool):
instantiate: bool | Sequence[bool], fwd: bool | Sequence[bool]):
f = lu.wrap_init(core.jaxpr_as_fun(jaxpr), debug_info=jaxpr.jaxpr.debug_info)

cell = []
@@ -1023,13 +1025,19 @@ def fun(*known_vals_in):
f, TraceTag(), jaxpr.jaxpr.debug_info, instantiate).call_wrapped(in_pvals)
jaxpr_unknown = convert_constvars_jaxpr(jaxpr_unknown_)
out_unknowns = [not pval.is_known() for pval in out_pvals]
if not fwd:
if type(fwd) is bool and not fwd:
residuals_ = iter(residuals)
residuals = [next(residuals_) if f is None else known_vals_in[f]
for f in fwds]
assert next(residuals_, None) is None
fwds = [None] * len(fwds)
else:
if type(fwd) is tuple:
fwd_ = [f for f, uk in zip(fwd, in_unknowns) if not uk]
residuals_, residuals = iter(residuals), []
fwds = [residuals.append(next(residuals_)) if f is None else
residuals.append(known_vals_in[f]) if not fwd_[f] else
f for f in fwds]
fwds, residuals = _include_consts_in_fwds(jaxpr.consts, fwds, residuals)
res_avals = [core.get_aval(r) for r in residuals]
cell.append((out_unknowns, jaxpr_unknown, res_avals, fwds))
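The new `fwd` argument generalizes residual forwarding from a single flag to a per-input mask: for each known input, True permits forwarding that input as a residual by index, while False forces the value to be materialized in the residual list instead. A minimal usage sketch (illustrative only, not from this PR; it assumes these internal APIs are importable as shown and that the toy function is representative):

import jax
import jax.numpy as jnp
from jax._src.interpreters import partial_eval as pe

# A body with a const-like input c, a carry-like input h, and an x-like input x.
def f(c, h, x):
  return c * h + x, jnp.sin(h)

closed = jax.make_jaxpr(f)(jnp.float32(2.), jnp.float32(1.), jnp.float32(0.))

# h is unknown; c and x are known. Mirroring the scan call in loops.py below,
# allow forwarding c and x as residuals but never a carry-position input.
known_jaxpr, unknown_jaxpr, out_unknowns, res_avals, in_fwd_res = \
    pe.partial_eval_jaxpr_nounits_fwd(
        closed, unknowns=[False, True, False],
        instantiate=False, fwd=[True, False, True])
# Reading the code above: in_fwd_res[i] is an index into the jaxpr consts plus
# known inputs when residual i is a forwarded input, and None when it must be
# computed as an output of known_jaxpr.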
182 changes: 82 additions & 100 deletions jax/_src/lax/control_flow/loops.py
@@ -65,7 +65,6 @@
from jax._src.util import (
merge_lists, partition_list, safe_map, safe_zip, split_list,
split_list_checked, unzip2, weakref_lru_cache,)
from jax._src import xla_bridge as xb
from jax.tree_util import (
keystr, tree_flatten, tree_flatten_with_path, tree_map, tree_unflatten,
treedef_is_leaf)
@@ -807,10 +806,34 @@ def _const_to_intensive_res_forwarding(
tangent_jaxpr, [False] * num_nz + [i is not None for i in const_to_res])
return primal_jaxpr, tangent_jaxpr, intensive_res, new_in_fwd

def _scan_known_hoisting(jaxpr_known, const_tracers, num_res):
# To disable:
# return jaxpr_known, [], [False] * num_res, []
consts = [pe.PartialVal.known(t.pval.get_known())
if not isinstance(t.aval, state.AbstractRef)
else pe.PartialVal.unknown(t.aval)
for t in const_tracers if t.pval.is_known()]
others = _map(pe.PartialVal.unknown, jaxpr_known.in_avals[len(consts):])
num_known_outs = len(jaxpr_known.out_avals) - num_res
dbg = jaxpr_known.jaxpr.debug_info
with source_info_util.reset_name_stack():
jaxpr_known_, invar_pvals_out, known_consts = pe.trace_to_jaxpr_nounits(
lu.wrap_init(core.jaxpr_as_fun(jaxpr_known), debug_info=dbg),
consts + others, instantiate=[True] * num_known_outs + [False] * num_res)
jaxpr_known = pe.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr_known_), ())
res_pvals = invar_pvals_out[len(invar_pvals_out) - num_res:]
which_hoisted = [pval.is_known() for pval in res_pvals]
hoisted_res = [pval.get_known() for pval in res_pvals if pval.is_known()]
mut_consts = [t.pval.get_known() for t in const_tracers
if t.pval.is_known() and isinstance(t.aval, state.AbstractRef)]
return jaxpr_known, [*known_consts, *mut_consts], which_hoisted, hoisted_res
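For intuition, the hoisting targets residuals that depend only on the scan consts, like a constant matrix applied at every step: broadcasting such a residual along the scan axis is pure waste. An illustrative setup that triggers this path (not part of the diff):

import jax
import jax.numpy as jnp

W = jnp.eye(3)  # loop-invariant: should be hoisted, not broadcast

def step(c, x):
  return W @ c + x, c.sum()

c0 = jnp.ones(3)
xs = jnp.zeros((5, 3))
# Differentiating through the scan makes W a residual of the known part;
# hoisting keeps it a single intensive value rather than an extensive
# per-iteration output of shape (5, 3, 3).
_, vjp_fn = jax.vjp(lambda c: jax.lax.scan(step, c, xs), c0)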


def _scan_partial_eval(trace, *tracers, reverse: bool,
length: int, num_consts: int, num_carry: int,
jaxpr: core.ClosedJaxpr, linear: Sequence[bool],
unroll: int, _split_transpose: bool):
num_xs = len(jaxpr.in_avals) - num_consts - num_carry
num_ys = len(jaxpr.out_avals) - num_carry
unknowns = [not t.pval.is_known() for t in tracers]
const_uk, init_uk, xs_uk = split_list(unknowns, [num_consts, num_carry])
@@ -822,117 +845,76 @@ def _scan_partial_eval(trace, *tracers, reverse: bool,
carry_uk = init_uk
for _ in range(1 + len(carry_uk)):
unknowns = const_uk + carry_uk + xs_uk
jaxpr_known, jaxpr_unknown, out_uk, res_avals = pe.partial_eval_jaxpr_nounits(
jaxpr, unknowns, instantiate=carry_uk + [False] * num_ys)
jaxpr_known, jaxpr_unknown, out_uk, res_avals, in_fwd_res = \
pe.partial_eval_jaxpr_nounits_fwd(
jaxpr, unknowns, instantiate=carry_uk + [False] * num_ys,
fwd=[True] * num_consts + [False] * num_carry + [True] * num_xs)
carry_uk_out, ys_uk = split_list(out_uk, [num_carry])
if carry_uk_out == carry_uk:
break
else:
carry_uk = _map(operator.or_, carry_uk, carry_uk_out)
else:
assert False, "Fixpoint not reached"
num_res = len(res_avals)
num_res_out, num_res_in = len(res_avals), len(in_fwd_res)
num_knowns_out = len(jaxpr_known.out_avals) - num_res_out
num_consts_known = num_consts - sum(const_uk)
num_carry_known = num_carry - sum(carry_uk)
del res_avals, carry_uk_out

# Instantiate those inputs which must be treated as unknown from the fixpoint.
tracers = tuple(trace.instantiate_const(t) if uk else t
for t, uk in zip(tracers, unknowns))

# The residual inputs and outputs of the jaxprs produced haven't yet been
# adapted to the scan calling convention; in particular, jaxpr_known has its
# residual outputs all at the end, meaning they're extensive outputs (which is
# fully general but may be wasteful for residuals which are loop-invariant)
# while jaxpr_unknown has its corresponding residual inputs at the front (just
# as a convention with partial_eval_jaxpr_nounits), making them constant
# inputs. To make them consistent, we move the residual inputs on
# jaxpr_unknown to the end, even though we may move some back in the sequel.
jaxpr_unknown = pe.move_binders_to_back(
jaxpr_unknown, [True] * num_res + [False] * sum(unknowns))

# At this point, all residuals are treated as extensive outputs of jaxpr_known
# (and extensive inputs to jaxpr_unknown). But residuals that are loop-
# invariant can be hoisted out of the scan, rather than letting them get
# broadcast (as in e.g. scanning multiplication by a constant matrix; we don't
# want to broadcast the matrix!). So, outside the loop we perform a partial
# evaluation with known 'const' inputs (but all other inputs unknown).
const_pvals = [pe.PartialVal.known(t.pval.get_known())
if not isinstance(t.aval, state.AbstractRef)
else pe.PartialVal.unknown(t.aval)
for t in tracers[:num_consts] if t.pval.is_known()]
other_pvals = [pe.PartialVal.unknown(aval)
for aval in jaxpr_known.in_avals[len(const_pvals):]]
with source_info_util.reset_name_stack():
jaxpr_known_, invar_pvals_out, jaxpr_known_consts = pe.trace_to_jaxpr_nounits(
lu.wrap_init(core.jaxpr_as_fun(jaxpr_known),
debug_info=jaxpr_known.jaxpr.debug_info),
const_pvals + other_pvals,
instantiate=[True] * (len(out_uk) - sum(out_uk)) + [False] * num_res)
jaxpr_known = pe.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr_known_), ())
# The above trace_to_jaxpr_nounits call computed loop-invariant residuals
# (known values in invar_pvals_out) and also computed loop-invariant values
# needed by the new jaxpr_known (in jaxpr_known_consts, which replace the
# previous consts). We need to collect the computed intensive residuals, and
# move corresponding intensive residual binders in jaxpr_unknown to the front.
res_pvals = invar_pvals_out[len(invar_pvals_out) - num_res:]
intensive_res = [pval.get_known() for pval in res_pvals if pval.is_known()]
jaxpr_unknown = pe.move_binders_to_front(
jaxpr_unknown,
[False] * sum(unknowns) + [pval.is_known() for pval in res_pvals])
del const_pvals, other_pvals, invar_pvals_out, jaxpr_known_, res_pvals
# We use `jaxpr_known_consts` when we call scan_p.bind with jaxpr_known, and
# we use `intensive_res` when we build the jaxpr eqn with jaxpr_unknown.

# As another optimization, for any extensive inputs that are just forwarded to
# extensive outputs, to avoid a copy (which would be looping over
# dynamic-update-slice) we'd rather forward the input tracer/value. That means
# pruning some outputs from jaxpr_known here, and updating `out_flat` below.
fwds_known = pe._jaxpr_forwarding(jaxpr_known.jaxpr)
# Prune fwds_known to include only extensive input to extensive output.
fwds_known = [in_idx if out_idx >= num_carry - sum(carry_uk) and
in_idx is not None and
in_idx >= len(jaxpr_known_consts) + num_carry - sum(carry_uk)
else None for out_idx, in_idx in enumerate(fwds_known)]
# Drop any extensive output we can instead get by forwarding an input.
# TODO(mattjj): use pe.dce_jaxpr here, though need a fixpoint
jaxpr_known_, () = jaxpr_known.jaxpr, jaxpr_known.consts
jaxpr_known_ = jaxpr_known_.replace(
outvars=[x for x, i in zip(jaxpr_known_.outvars, fwds_known) if i is None])
jaxpr_known = core.ClosedJaxpr(jaxpr_known_, ())
del jaxpr_known_
# We use `fwds_known` below when forming the output of scanning jaxpr_known.
tracers = [trace.instantiate_const(t) if uk else t
for t, uk in zip(tracers, unknowns)]
# Keep original known inputs, since in_fwd_res indexes into them.
orig_inputs = [*jaxpr_known.consts,
*[t.pval.get_known() for t in tracers if t.pval.is_known()]]

# At this point all non-forwarded residuals are treated as extensive outputs
# of jaxpr_known. Hoist out those that only depend on consts.
# Before: jaxpr_known: [*known_ins] -> [*known_outs, *non_fwd_res]
# After: jaxpr_known: [*known_consts, *known_ins] -> [*known_outs, *ext_res]
# where, modulo hoisted res not being broadcast, we have
# non_fwd_res = merge_lists(which_hoisted, ext_res, hoisted_res)
jaxpr_known, known_consts, which_hoisted, hoisted_res = \
_scan_known_hoisting(jaxpr_known, tracers[:num_consts], num_res_out)

# To make jaxpr_unknown match the scan calling convention, move to the back
# binders that don't correspond to hoisted or carry-forwarded residuals.
# Before: jaxpr_unknown: [*res, *unknown_ins] -> [*unknown_outs]
# After:  jaxpr_unknown: [*int_res, *unknown_ins, *ext_res] -> [*unknown_outs]
num_unk_in = len(jaxpr_unknown.in_avals) - num_res_in
which_hoisted_ = iter(which_hoisted)
res_to_move = [not next(which_hoisted_) if f is None else
f >= num_consts_known + num_carry_known for f in in_fwd_res]
jaxpr_unknown = pe.move_binders_to_back(jaxpr_unknown,
res_to_move + [False] * num_unk_in)
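Here move_binders_to_back simply permutes the jaxpr's invars so the True-flagged binders land at the end; a toy sketch of that contract (assuming the internal helper behaves as it is used here):

import jax
from jax._src.interpreters import partial_eval as pe

cj = jax.make_jaxpr(lambda a, b, c: a + b * c)(1., 2., 3.)
# Move the first binder to the back: the result maps (b, c, a) -> b * c + a.
cj2 = pe.move_binders_to_back(cj, [True, False, False])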

# Run the known part of the scan (if it has any outputs or effects).
known_mutable_consts = [t.pval.get_known() for t in tracers[:num_consts]
if t.pval.is_known() and isinstance(t.aval, state.AbstractRef)]
known_inputs = (list(jaxpr_known_consts) + known_mutable_consts +
[t.pval.get_known() for t in tracers[num_consts:]
if t.pval.is_known()])
known_ins = [t.pval.get_known() for t in tracers[num_consts:] if t.pval.is_known()]
if not jaxpr_known.out_avals and not jaxpr_known.effects:
out_known = []
known_outs_ext_res = []
else:
linear_known = [False] * len(known_inputs) # conservative!
out_known = scan_p.bind(
*known_inputs, reverse=reverse, length=length, jaxpr=jaxpr_known,
num_consts=len(jaxpr_known_consts) + len(known_mutable_consts),
num_carry=num_carry - sum(carry_uk),
linear=tuple(linear_known), unroll=unroll,
linear_known = ([False] * len(known_consts) +
                [l for l, uk in zip(linear[num_consts:], unknowns[num_consts:]) if not uk])
known_outs_ext_res = scan_p.bind(
*known_consts, *known_ins, jaxpr=jaxpr_known, reverse=reverse,
length=length, num_consts=len(known_consts),
num_carry=num_carry_known, linear=tuple(linear_known), unroll=unroll,
_split_transpose=_split_transpose)
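Note the linearity bookkeeping: the old call marked every known input non-linear (the "# conservative!" above), whereas the new call preserves the original `linear` annotations for the surviving known inputs, prefixing False only for the known consts.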
del linear_known
# Complete the known output by filling in forwarded values using fwds_known.
out_known_iter = iter(out_known)
out_known = [next(out_known_iter) if f is None
else _maybe_put(known_inputs[f]) for f in fwds_known]
assert next(out_known_iter, None) is None
del known_inputs, out_known_iter

# Split known outputs from residuals.
out_known, extensive_res = split_list(out_known, [len(out_uk) - sum(out_uk)])
assert len(intensive_res) + len(extensive_res) == num_res
known_outs, ext_res = split_list(known_outs_ext_res, [num_knowns_out])

# Complete non_fwd_res and then res, then split to match binders.
non_fwd_res = merge_lists(which_hoisted, ext_res, hoisted_res)
non_fwd_res_ = iter(non_fwd_res)
res = [next(non_fwd_res_) if f is None else _maybe_put(orig_inputs[f])
for f in in_fwd_res]
assert next(non_fwd_res_, None) is None
int_res, ext_res = partition_list(res_to_move, res)
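The residual bookkeeping above leans on two small list utilities; a pure-Python sketch of their behavior (assumed to match jax._src.util) may help when reading these lines:

def merge_lists(which, l_false, l_true):
  # Interleave two lists according to a boolean mask.
  i_false, i_true = iter(l_false), iter(l_true)
  return [next(i_true) if w else next(i_false) for w in which]

def partition_list(which, xs):
  # Split xs into (mask-False entries, mask-True entries).
  out = ([], [])
  for w, x in zip(which, xs):
    out[w].append(x)
  return out

assert merge_lists([True, False, True], ['b'], ['a', 'c']) == ['a', 'b', 'c']
assert partition_list([False, True, False], [1, 2, 3]) == ([1, 3], [2])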

# Create input tracers for jaxpr_unknown bind.
unknown_inputs = [t for t in tracers if not t.pval.is_known()]
intensive_res = _map(trace.new_instantiated_const, intensive_res)
extensive_res = _map(trace.new_instantiated_const, extensive_res)
int_res = _map(trace.new_instantiated_const, int_res)
ext_res = _map(trace.new_instantiated_const, ext_res)
# Create output tracers for jaxpr_unknown bind, adapting extensive shapes.
carry_avals, y_avals = split_list(jaxpr_unknown.out_avals, [sum(carry_uk)])
ys_avals = [core.unmapped_aval(length, 0, y_aval)
@@ -941,29 +923,29 @@ def _scan_partial_eval(trace, *tracers, reverse: bool,
for a in it.chain(carry_avals, ys_avals)]
del carry_avals, y_avals
# Create equation.
linear_unknown = tuple([False] * len(intensive_res) +
linear_unknown = tuple([False] * len(int_res) +
[l for l, uk in zip(linear, unknowns) if uk] +
[False] * len(extensive_res))
[False] * len(ext_res))
name_stack = source_info_util.current_name_stack()[len(trace.name_stack):]
source = source_info_util.current().replace(name_stack=name_stack)
assert len(out_tracers) == len(jaxpr_unknown.out_avals)
eqn = pe.new_eqn_recipe(trace, [*intensive_res, *unknown_inputs, *extensive_res],
eqn = pe.new_eqn_recipe(trace, [*int_res, *unknown_inputs, *ext_res],
out_tracers, scan_p,
dict(reverse=reverse, length=length, unroll=unroll,
jaxpr=jaxpr_unknown, linear=linear_unknown,
num_consts=len(intensive_res) + sum(const_uk),
num_consts=len(int_res) + sum(const_uk),
num_carry=sum(carry_uk),
_split_transpose=_split_transpose),
jaxpr_unknown.effects, source)
for t in out_tracers: t.recipe = eqn

# Merge known and unknown outputs into final result.
return util.merge_lists(out_uk, out_known, out_tracers)
return util.merge_lists(out_uk, known_outs, out_tracers)

def _maybe_put(x):
if isinstance(x, np.ndarray):
aval = core.shaped_abstractify(x)
s = sharding.SingleDeviceSharding(xb.local_devices(backend='cpu')[0])
s = sharding.SingleDeviceSharding(pxla.get_default_device())
result_handler = pxla.global_aval_to_result_handler(aval, s, False)
return result_handler(pxla.shard_args([s], [None], [None], [x]))
else:
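_maybe_put exists to move raw numpy residuals onto a JAX device before they are fed back in as scan inputs. A rough public-API analogue (an assumption for illustration, not the internal result-handler path above):

import numpy as np
import jax

def maybe_put(x):
  # numpy arrays get placed on a device; jax.Arrays and tracers pass through.
  return jax.device_put(x) if isinstance(x, np.ndarray) else x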
50 changes: 50 additions & 0 deletions tests/mutable_array_test.py
@@ -365,6 +365,56 @@ def loss(x, y):
jax.grad(loss, (0,1))(x_top, y_top)
self.assertAllClose(dot_op.amax_history[:], jnp.zeros((5,)).at[:i+1].set(1.0), check_dtypes=False)

@parameterized.parameters([False, True])
def test_custom_vjp_grad_stats_plumbing_basic(self, jit):
@jax.jit
def primal(grads_ref, x): # note: jit-abstracted!
x = jnp.sin(x)
x = stash_grads(grads_ref, x)
x = jnp.sin(x)
x = stash_grads(grads_ref, x) # ignored, order-preserved
return x

@jax.custom_vjp
def stash_grads(grads_ref, x):
return x
def stash_grads_fwd(grads_ref, x):
return x, grads_ref
def stash_grads_bwd(grads_ref, g):
grads_ref[...] = g
return None, g
stash_grads.defvjp(stash_grads_fwd, stash_grads_bwd)

grads_ref = core.mutable_array(jnp.float32(0.))
jax.grad(primal, 1)(grads_ref, jnp.float32(1.0))
self.assertAllClose(grads_ref[...], jnp.cos(jnp.sin(1.)), check_dtypes=False)

@parameterized.parameters([False, True])
def test_custom_vjp_grad_stats_plumbing_scan(self, jit):
@jax.jit
def primal(grads_ref, x): # note: jit-abstracted!
def body(x, _):
x = jnp.sin(x)
x = stash_grads(grads_ref, x)
x = jnp.sin(x)
return x, ()
x, () = jax.lax.scan(body, x, None, length=1)
return x

@jax.custom_vjp
def stash_grads(grads_ref, x):
return x
def stash_grads_fwd(grads_ref, x):
return x, grads_ref
def stash_grads_bwd(grads_ref, g):
grads_ref[...] = g
return None, g
stash_grads.defvjp(stash_grads_fwd, stash_grads_bwd)

grads_ref = core.mutable_array(jnp.float32(0.))
jax.grad(primal, argnums=1)(grads_ref, jnp.float32(1.0))
self.assertAllClose(grads_ref[...], jnp.cos(jnp.sin(1.)), check_dtypes=False)
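Of the two tests, the scan variant exercises the new path in this PR: grads_ref is closed over by the scan body, so _scan_partial_eval must carry a mutable-array const through partial_eval_jaxpr_nounits_fwd without treating it as a forwardable residual.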


@jtu.with_config(jax_mutable_array_checks=True)
class MutableArrayErrorsTest(jtu.JaxTestCase):