Skip to content

[mutable-arrays] re-land #29353 #29421

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions jax/_src/interpreters/partial_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -992,14 +992,16 @@ def partial_eval_jaxpr_nounits(
def partial_eval_jaxpr_nounits_fwd(
jaxpr: ClosedJaxpr, unknowns: Sequence[bool],
instantiate: bool | Sequence[bool],
fwd: bool | Sequence[bool] = True,
) -> tuple[ClosedJaxpr, ClosedJaxpr, list[bool], list[AbstractValue], list[int | None]]:
instantiate = tuple(instantiate) if isinstance(instantiate, list) else instantiate
return _partial_eval_jaxpr_nounits(jaxpr, tuple(unknowns), instantiate, True)
fwd = tuple(fwd) if isinstance(fwd, list) else fwd
return _partial_eval_jaxpr_nounits(jaxpr, tuple(unknowns), instantiate, fwd)

@weakref_lru_cache
def _partial_eval_jaxpr_nounits(
jaxpr: ClosedJaxpr, in_unknowns: Sequence[bool],
instantiate: bool | Sequence[bool], fwd: bool):
instantiate: bool | Sequence[bool], fwd: bool | Sequence[bool]):
f = lu.wrap_init(core.jaxpr_as_fun(jaxpr), debug_info=jaxpr.jaxpr.debug_info)

cell = []
Expand All @@ -1013,13 +1015,19 @@ def fun(*known_vals_in):
f, TraceTag(), jaxpr.jaxpr.debug_info, instantiate).call_wrapped(in_pvals)
jaxpr_unknown = convert_constvars_jaxpr(jaxpr_unknown_)
out_unknowns = [not pval.is_known() for pval in out_pvals]
if not fwd:
if type(fwd) is bool and not fwd:
residuals_ = iter(residuals)
residuals = [next(residuals_) if f is None else known_vals_in[f]
for f in fwds]
assert next(residuals_, None) is None
fwds = [None] * len(fwds)
else:
if type(fwd) is tuple:
fwd_ = [f for f, uk in zip(fwd, in_unknowns) if not uk]
residuals_, residuals = iter(residuals), []
fwds = [residuals.append(next(residuals_)) if f is None else
residuals.append(known_vals_in[f]) if not fwd_[f] else
f for f in fwds]
fwds, residuals = _include_consts_in_fwds(jaxpr.consts, fwds, residuals)
res_avals = [core.get_aval(r) for r in residuals]
cell.append((out_unknowns, jaxpr_unknown, res_avals, fwds))
Expand Down
190 changes: 85 additions & 105 deletions jax/_src/lax/control_flow/loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from jax._src import util
from jax._src.api_util import (
_check_no_aliased_ref_args, _check_no_aliased_closed_over_refs)
from jax._src.core import ShapedArray
from jax._src.core import ShapedArray, typeof
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
Expand Down Expand Up @@ -809,6 +809,27 @@ def _const_to_intensive_res_forwarding(
tangent_jaxpr, [False] * num_nz + [i is not None for i in const_to_res])
return primal_jaxpr, tangent_jaxpr, intensive_res, new_in_fwd

def _scan_known_hoisting(jaxpr_known, known_consts, num_res):
# To disable:
# return jaxpr_known, known_consts, [False] * num_res, []

consts = [pe.PartialVal.unknown(a) if isinstance(a := typeof(c), state.AbstractRef)
else pe.PartialVal.known(c) for c in known_consts]
others = _map(pe.PartialVal.unknown, jaxpr_known.in_avals[len(consts):])
num_known_outs = len(jaxpr_known.out_avals) - num_res
with source_info_util.reset_name_stack():
jaxpr_known_, pvals_out, new_known_consts = pe.trace_to_jaxpr_nounits(
lu.wrap_init(core.jaxpr_as_fun(jaxpr_known),
debug_info=jaxpr_known.jaxpr.debug_info),
consts + others, instantiate=[True] * num_known_outs + [False] * num_res)
jaxpr_known = pe.close_jaxpr(pe.convert_constvars_jaxpr(jaxpr_known_))
res_pvals = pvals_out[num_known_outs:]
which_hoisted = [pval.is_known() for pval in res_pvals]
hoisted_res = [pval.get_known() for pval in res_pvals if pval.is_known()]
mut_consts = [c for c in known_consts if isinstance(typeof(c), state.AbstractRef)]
return jaxpr_known, [*new_known_consts, *mut_consts], which_hoisted, hoisted_res


def _scan_partial_eval(trace, *tracers, reverse: bool,
length: int, num_consts: int, num_carry: int,
jaxpr: core.ClosedJaxpr, linear: Sequence[bool],
Expand All @@ -819,148 +840,107 @@ def _scan_partial_eval(trace, *tracers, reverse: bool,

# Fixpoint computation of which carry elements are unknown. Each iteration
# promotes at least one carry to unknown. We need at most len(carry)
# iterations, but we need one last iteration to prepare the jaxpr based on the
# final carry_uk.
# iterations to decide carry_uk, plus one to prepare the jaxpr.
carry_uk = init_uk
# Don't allow forwarding from the carry or numpy.ndarrays.
fwd = [(i < num_consts or i >= num_consts + num_carry) and
not isinstance(t.pval.get_known(), np.ndarray) for i, t in enumerate(tracers)]
for _ in range(1 + len(carry_uk)):
unknowns = const_uk + carry_uk + xs_uk
jaxpr_known, jaxpr_unknown, out_uk, res_avals = pe.partial_eval_jaxpr_nounits(
jaxpr, unknowns, instantiate=carry_uk + [False] * num_ys)
jaxpr_known, jaxpr_unknown, out_uk, res_avals, in_fwd_res = \
pe.partial_eval_jaxpr_nounits_fwd(
jaxpr, unknowns, instantiate=carry_uk + [False] * num_ys, fwd=fwd)
carry_uk_out, ys_uk = split_list(out_uk, [num_carry])
if carry_uk_out == carry_uk:
break
else:
carry_uk = _map(operator.or_, carry_uk, carry_uk_out)
else:
assert False, "Fixpoint not reached"
num_res = len(res_avals)
num_res_out, num_res_in = len(res_avals), len(in_fwd_res)
num_knowns_out = len(jaxpr_known.out_avals) - num_res_out
num_consts_known = num_consts - sum(const_uk)
num_carry_known = num_carry - sum(carry_uk)
del res_avals, carry_uk_out

# Instantiate those inputs which must be treated as unknown from the fixpoint.
tracers = tuple(trace.instantiate_const(t) if uk else t
for t, uk in zip(tracers, unknowns))

# The residual inputs and outputs of the jaxprs produced haven't yet been
# adapted to the scan calling convention; in particular, jaxpr_known has its
# residual outputs all at the end, meaning they're extensive outputs (which is
# fully general but may be wasteful for residuals which are loop-invariant)
# while jaxpr_unknown has its corresponding residual inputs at the front (just
# as a convention with partial_eval_jaxpr_nounits), making them constant
# inputs. To make them consistent, we move the residual inputs on
# jaxpr_unknown to the end, even though we may move some back in the sequel.
tracers = [trace.instantiate_const(t) if uk else t
for t, uk in zip(tracers, unknowns)]
known_ins = [t.pval.get_known() for t in tracers if t.pval.is_known()]
unknown_ins = [t for t in tracers if not t.pval.is_known()]

# At this point all non-forwarded residuals are treated as extensive outputs
# of jaxpr_known. Hoist out those that only depend on consts.
# Before: jaxpr_known: [*known_ins] -> [*known_outs, *non_fwd_res]
# After: jaxpr_known: [*known_consts_, *known_ins] -> [*known_outs, *ext_res]
# where, modulo hoisted res not being broadcast, we have
# non_fwd_res = merge_lists(which_hoisted, ext_res, hoisted_res)
known_consts, known_ins = split_list(known_ins, [num_consts_known])
jaxpr_known, known_consts_, which_hoisted, hoisted_res = \
_scan_known_hoisting(jaxpr_known, known_consts, num_res_out)

# To make jaxpr_unknown match the scan calling convention, move to the back
# binders that don't correspond to hoisted or const-forwarded residuals.
# Before: jaxpr_unknown: [*res, *unknown_ins] -> [*unkown_outs]
# After: jaxpr_unkonwn: [*int_res, *unknown_ins, *ext_res] -> [*unknown_outs]
num_unk_in = len(jaxpr_unknown.in_avals) - num_res_in
which_hoisted_ = iter(which_hoisted)
res_to_move = [not next(which_hoisted_) if f is None else
f >= len(jaxpr.consts) + num_consts_known + num_carry_known
for f in in_fwd_res]
assert next(which_hoisted_, None) is None
jaxpr_unknown = pe.move_binders_to_back(
jaxpr_unknown, [True] * num_res + [False] * sum(unknowns))

# At this point, all residuals are treated as extensive outputs of jaxpr_known
# (and extensive inputs to jaxpr_unknown). But residuals that are loop-
# invariant can be hoisted out of the scan, rather than letting them get
# broadcast (as in e.g. scanning multiplication by a constant matrix; we don't
# want to broadcast the matrix!). So, outside the loop we perform a partial
# evaluation with known 'const' inputs (but all other inputs unknown).
const_pvals = [pe.PartialVal.known(t.pval.get_known())
if not isinstance(t.aval, state.AbstractRef)
else pe.PartialVal.unknown(t.aval)
for t in tracers[:num_consts] if t.pval.is_known()]
other_pvals = [pe.PartialVal.unknown(aval)
for aval in jaxpr_known.in_avals[len(const_pvals):]]
with source_info_util.reset_name_stack():
jaxpr_known_, invar_pvals_out, jaxpr_known_consts = pe.trace_to_jaxpr_nounits(
lu.wrap_init(core.jaxpr_as_fun(jaxpr_known),
debug_info=jaxpr_known.jaxpr.debug_info),
const_pvals + other_pvals,
instantiate=[True] * (len(out_uk) - sum(out_uk)) + [False] * num_res)
jaxpr_known = pe.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr_known_), ())
# The above trace_to_jaxpr_nounits call computed loop-invariant residuals
# (known values in invar_pvals_out) and also computed loop-invariant values
# needed by the new jaxpr_known (in jaxpr_known_consts, which replace the
# previous consts). We need to collect the computed intensive residuals, and
# move corresponding intensive residual binders in jaxpr_unknown to the front.
res_pvals = invar_pvals_out[len(invar_pvals_out) - num_res:]
intensive_res = [pval.get_known() for pval in res_pvals if pval.is_known()]
jaxpr_unknown = pe.move_binders_to_front(
jaxpr_unknown,
[False] * sum(unknowns) + [pval.is_known() for pval in res_pvals])
del const_pvals, other_pvals, invar_pvals_out, jaxpr_known_, res_pvals
# We use `jaxpr_known_consts` when we call scan_p.bind with jaxpr_known, and
# we use `intensive_res` when we build the jaxpr eqn with jaxpr_unknown.

# As another optimization, for any extensive inputs that are just forwarded to
# extensive outputs, to avoid a copy (which would be looping over
# dynamic-update-slice) we'd rather forward the input tracer/value. That means
# pruning some outputs from jaxpr_known here, and updating `out_flat` below.
fwds_known = pe._jaxpr_forwarding(jaxpr_known.jaxpr)
# Prune fwds_known to include only extensive input to extensive output.
fwds_known = [in_idx if out_idx >= num_carry - sum(carry_uk) and
in_idx is not None and
in_idx >= len(jaxpr_known_consts) + num_carry - sum(carry_uk)
else None for out_idx, in_idx in enumerate(fwds_known)]
# Drop any extensive output we can instead get by forwarding an input.
# TODO(mattjj): use pe.dce_jaxpr here, though need a fixpoint
jaxpr_known_, () = jaxpr_known.jaxpr, jaxpr_known.consts
jaxpr_known_ = jaxpr_known_.replace(
outvars=[x for x, i in zip(jaxpr_known_.outvars, fwds_known) if i is None])
jaxpr_known = core.ClosedJaxpr(jaxpr_known_, ())
del jaxpr_known_
# We use `fwds_known` below when forming the output of scanning jaxpr_known.
jaxpr_unknown, res_to_move + [False] * num_unk_in)

# Run the known part of the scan (if it has any outputs or effects).
known_mutable_consts = [t.pval.get_known() for t in tracers[:num_consts]
if t.pval.is_known() and isinstance(t.aval, state.AbstractRef)]
known_inputs = (list(jaxpr_known_consts) + known_mutable_consts +
[t.pval.get_known() for t in tracers[num_consts:]
if t.pval.is_known()])
linear_known, linear_unknown = partition_list(unknowns, linear)
if not jaxpr_known.out_avals and not jaxpr_known.effects:
out_known = []
known_outs_ext_res = []
else:
linear_known = [False] * len(known_inputs) # conservative!
out_known = scan_p.bind(
*known_inputs, reverse=reverse, length=length, jaxpr=jaxpr_known,
num_consts=len(jaxpr_known_consts) + len(known_mutable_consts),
num_carry=num_carry - sum(carry_uk),
linear=tuple(linear_known), unroll=unroll,
linear_known = [False] * len(jaxpr_known.in_avals) # TODO conservative
assert len(known_consts_) + len(known_ins) == len(jaxpr_known.in_avals)
known_outs_ext_res = scan_p.bind(
*known_consts_, *known_ins, jaxpr=jaxpr_known, reverse=reverse,
length=length, num_consts=len(known_consts_),
num_carry=num_carry_known, linear=(*linear_known,), unroll=unroll,
_split_transpose=_split_transpose)
del linear_known
# Complete the known output by filling in forwarded values using fwds_known.
out_known_iter = iter(out_known)
out_known = [next(out_known_iter) if f is None
else _maybe_put(known_inputs[f]) for f in fwds_known]
assert next(out_known_iter, None) is None
del known_inputs, out_known_iter

# Split known outputs from residuals.
out_known, extensive_res = split_list(out_known, [len(out_uk) - sum(out_uk)])
assert len(intensive_res) + len(extensive_res) == num_res
known_outs, ext_res = split_list(known_outs_ext_res, [num_knowns_out])

# Complete non_fwd_res and then res, then split to match binders.
non_fwd_res = merge_lists(which_hoisted, ext_res, hoisted_res)
non_fwd_res_ = iter(non_fwd_res)
res = [next(non_fwd_res_) if f is None
else [*jaxpr.consts, *known_consts, *known_ins][f] for f in in_fwd_res]
assert next(non_fwd_res_, None) is None
int_res, ext_res = partition_list(res_to_move, res)

# Create input tracers for jaxpr_unknown bind.
unknown_inputs = [t for t in tracers if not t.pval.is_known()]
intensive_res = _map(trace.new_instantiated_const, intensive_res)
extensive_res = _map(trace.new_instantiated_const, extensive_res)
int_res = _map(trace.new_instantiated_const, int_res)
ext_res = _map(trace.new_instantiated_const, ext_res)
# Create output tracers for jaxpr_unknown bind, adapting extensive shapes.
carry_avals, y_avals = split_list(jaxpr_unknown.out_avals, [sum(carry_uk)])
ys_avals = [core.unmapped_aval(length, 0, y_aval)
for y_aval in y_avals]
ys_avals = [core.unmapped_aval(length, 0, y_aval) for y_aval in y_avals]
out_tracers = [pe.JaxprTracer(trace, pe.PartialVal.unknown(a), None)
for a in it.chain(carry_avals, ys_avals)]
del carry_avals, y_avals
# Create equation.
linear_unknown = tuple([False] * len(intensive_res) +
[l for l, uk in zip(linear, unknowns) if uk] +
[False] * len(extensive_res))
linear_unknown = [False] * len(int_res) + linear_unknown + [False] * len(ext_res)
assert len(linear_unknown) == len(jaxpr_unknown.in_avals)
name_stack = source_info_util.current_name_stack()[len(trace.name_stack):]
source = source_info_util.current().replace(name_stack=name_stack)
assert len(out_tracers) == len(jaxpr_unknown.out_avals)
eqn = pe.new_eqn_recipe(trace, [*intensive_res, *unknown_inputs, *extensive_res],
eqn = pe.new_eqn_recipe(trace, [*int_res, *unknown_inputs, *ext_res],
out_tracers, scan_p,
dict(reverse=reverse, length=length, unroll=unroll,
jaxpr=jaxpr_unknown, linear=linear_unknown,
num_consts=len(intensive_res) + sum(const_uk),
jaxpr=jaxpr_unknown, linear=(*linear_unknown,),
num_consts=len(int_res) + sum(const_uk),
num_carry=sum(carry_uk),
_split_transpose=_split_transpose),
jaxpr_unknown.effects, source)
for t in out_tracers: t.recipe = eqn

# Merge known and unknown outputs into final result.
return util.merge_lists(out_uk, out_known, out_tracers)
return util.merge_lists(out_uk, known_outs, out_tracers)

def _maybe_put(x):
if isinstance(x, np.ndarray):
Expand Down
91 changes: 91 additions & 0 deletions tests/lax_control_flow_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import contextlib
from functools import partial
import itertools
import math
import operator
import re
import unittest
Expand Down Expand Up @@ -3395,6 +3396,96 @@ def g(x):
jaxpr = jax.make_jaxpr(g)(jnp.arange(3.))
self.assertLen(jaxpr.eqns[0].params["jaxpr"].jaxpr.outvars, 1)

@jtu.sample_product(
seed=range(6),
num_rule_consts=range(6),
num_const_fwds=range(6),
num_carry_fwds=range(6),
num_input_fwds=range(6),
)
@jtu.run_on_devices("cpu")
def test_scan_vjp_forwarding_correctness(
self,
seed,
num_rule_consts,
num_const_fwds,
num_carry_fwds,
num_input_fwds):
# Unlike test_scan_forwarding_correctness, which tests forwarding in the
# scan traceable, this test covers forwarding logic related to residuals in
# the scan partial eval / vjp rule. So 'forwards' refer to residuals that
# will be forwarded.

# We use a custom_jvp where the jvp rule introduces consts to populate
# jaxpr.consts in _scan_partial_eval's input.
@jax.custom_jvp
def foo(x):
return 3. * x
@foo.defjvp
def foo_jvp(primals, tangents):
(x,), (x_dot,) = primals, tangents
if num_rule_consts:
coeff = sum([jnp.array(np.ones(3) / num_rule_consts) for _ in range(num_rule_consts)]) # noqa: C419
else:
coeff = 1.
return foo(x), jnp.prod(coeff) * x_dot

num_const = num_const_fwds + 2
num_carry = num_carry_fwds + 4
num_xs = num_input_fwds + 2
num_ys = num_xs + 1

rng = np.random.RandomState(seed)
carry_perm = rng.permutation(num_carry)
carry_iperm = np.argsort(carry_perm)

xs_perm = rng.permutation(num_xs)
ys_perm = rng.permutation(num_ys)
f = np.arange(num_xs)
f = [f[i] if idx < num_input_fwds else None for idx, i in enumerate(xs_perm)]
f += [None]
in_fwd = [f[i] for i in ys_perm]

body_consts = [jnp.array(rng.randn(3)) for _ in range(num_const)]
init_vals = list(map(jnp.array, rng.uniform(size=(num_carry, 3))))

def body_fun(c, x):
c = [c[i] for i in carry_iperm]

const_fwds, const_dont_fwd = split_list(body_consts, [num_const_fwds])
z = sum(const_dont_fwd)

carry_fwds, carry_dont_fwd = split_list(c, [num_const_fwds])
carry_fwds = [math.prod([x, x, *const_fwds, z]) for x in carry_fwds]
carry_dont_fwd = [jnp.sin(x) * sum(jnp.sum(c) for c in body_consts)
for x in carry_dont_fwd]
new_c_perm = [*carry_fwds, *carry_dont_fwd]
new_c = [new_c_perm[i] for i in carry_perm]
new_c = [foo(new_c[0]), *new_c[1:]]

x = [x[i] for i in xs_perm]
x_fwd, x_dont_fwd = split_list(x, [num_input_fwds])
x_fwd = [x * x for x in x_fwd]
x_dont_fwd = [jnp.cos(x) * sum(jnp.sum(c) for c in body_consts)
for x in x_dont_fwd]
y = [*x_fwd, *x_dont_fwd, 0]
y = [y[i] for i in ys_perm]

return new_c, y

xs = list(map(jnp.array, rng.uniform(size=(num_xs, 2))))

(final, outs), vjp = jax.vjp(partial(jax.lax.scan, body_fun), init_vals, xs)
init_vals_bar, xs_bar = vjp((final, outs))

with jax.disable_jit():
(final_ref, outs_ref), vjp = jax.vjp(partial(jax.lax.scan, body_fun), init_vals, xs)
init_vals_bar_ref, xs_bar_ref = vjp((final, outs))

self.assertAllClose(final, final_ref, check_dtypes=False)
self.assertAllClose(outs, outs_ref, check_dtypes=False)
self.assertAllClose(xs_bar, xs_bar_ref, check_dtypes=False)


if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())
Loading
Loading