Skip to content

Commit f6ea9ac

Browse files
committed
[mutable-arrays] re-land #29353
1 parent 45e61d8 commit f6ea9ac

File tree

4 files changed

+232
-109
lines changed

4 files changed

+232
-109
lines changed

jax/_src/interpreters/partial_eval.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -992,14 +992,16 @@ def partial_eval_jaxpr_nounits(
992992
def partial_eval_jaxpr_nounits_fwd(
    jaxpr: ClosedJaxpr, unknowns: Sequence[bool],
    instantiate: bool | Sequence[bool],
    fwd: bool | Sequence[bool] = True,
) -> tuple[ClosedJaxpr, ClosedJaxpr, list[bool], list[AbstractValue], list[int | None]]:
  """Partially evaluate ``jaxpr``, optionally forwarding known inputs as residuals.

  Thin wrapper that normalizes its arguments and delegates to
  ``_partial_eval_jaxpr_nounits``.

  Args:
    jaxpr: the closed jaxpr to split into known and unknown parts.
    unknowns: one flag per input of ``jaxpr``; True marks the input as unknown.
    instantiate: a single bool or one flag per output, passed through to the
      underlying partial evaluation.
    fwd: a single bool or one flag per input. A bool enables/disables residual
      forwarding for all inputs at once; a sequence selects, per input, whether
      a residual that is just a known input may be forwarded instead of being
      emitted as a fresh residual output.

  Returns:
    Whatever ``_partial_eval_jaxpr_nounits`` returns; callers unpack it as
    ``(jaxpr_known, jaxpr_unknown, out_unknowns, res_avals, fwds)``.
  """
  # The delegate is memoized with weakref_lru_cache, so list arguments are
  # converted to tuples before the call (presumably to keep them hashable).
  instantiate = tuple(instantiate) if isinstance(instantiate, list) else instantiate
  fwd = tuple(fwd) if isinstance(fwd, list) else fwd
  # Pass the caller's `fwd` through; do not hard-code True here, otherwise the
  # per-input forwarding flags would be silently ignored.
  return _partial_eval_jaxpr_nounits(jaxpr, tuple(unknowns), instantiate, fwd)
9981000

9991001
@weakref_lru_cache
10001002
def _partial_eval_jaxpr_nounits(
10011003
jaxpr: ClosedJaxpr, in_unknowns: Sequence[bool],
1002-
instantiate: bool | Sequence[bool], fwd: bool):
1004+
instantiate: bool | Sequence[bool], fwd: bool | Sequence[bool]):
10031005
f = lu.wrap_init(core.jaxpr_as_fun(jaxpr), debug_info=jaxpr.jaxpr.debug_info)
10041006

10051007
cell = []
@@ -1013,13 +1015,19 @@ def fun(*known_vals_in):
10131015
f, TraceTag(), jaxpr.jaxpr.debug_info, instantiate).call_wrapped(in_pvals)
10141016
jaxpr_unknown = convert_constvars_jaxpr(jaxpr_unknown_)
10151017
out_unknowns = [not pval.is_known() for pval in out_pvals]
1016-
if not fwd:
1018+
if type(fwd) is bool and not fwd:
10171019
residuals_ = iter(residuals)
10181020
residuals = [next(residuals_) if f is None else known_vals_in[f]
10191021
for f in fwds]
10201022
assert next(residuals_, None) is None
10211023
fwds = [None] * len(fwds)
10221024
else:
1025+
if type(fwd) is tuple:
1026+
fwd_ = [f for f, uk in zip(fwd, in_unknowns) if not uk]
1027+
residuals_, residuals = iter(residuals), []
1028+
fwds = [residuals.append(next(residuals_)) if f is None else
1029+
residuals.append(known_vals_in[f]) if not fwd_[f] else
1030+
f for f in fwds]
10231031
fwds, residuals = _include_consts_in_fwds(jaxpr.consts, fwds, residuals)
10241032
res_avals = [core.get_aval(r) for r in residuals]
10251033
cell.append((out_unknowns, jaxpr_unknown, res_avals, fwds))

jax/_src/lax/control_flow/loops.py

Lines changed: 86 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
from jax._src import util
3838
from jax._src.api_util import (
3939
_check_no_aliased_ref_args, _check_no_aliased_closed_over_refs)
40-
from jax._src.core import ShapedArray
40+
from jax._src.core import ShapedArray, typeof
4141
from jax._src.interpreters import ad
4242
from jax._src.interpreters import batching
4343
from jax._src.interpreters import mlir
@@ -809,6 +809,27 @@ def _const_to_intensive_res_forwarding(
809809
tangent_jaxpr, [False] * num_nz + [i is not None for i in const_to_res])
810810
return primal_jaxpr, tangent_jaxpr, intensive_res, new_in_fwd
811811

812+
def _scan_known_hoisting(jaxpr_known, known_consts, num_res):
  """Hoist residuals of ``jaxpr_known`` that depend only on its consts.

  ``jaxpr_known`` is re-traced with its leading ``known_consts`` inputs marked
  as known (except mutable-array refs, which stay unknown) and every other
  input marked as unknown. Residual outputs that come out known from that trace
  were computable from the consts alone and are returned as concrete values.

  Args:
    jaxpr_known: closed jaxpr whose last ``num_res`` outputs are residuals.
    known_consts: concrete values for the leading inputs of ``jaxpr_known``.
    num_res: number of trailing residual outputs of ``jaxpr_known``.

  Returns:
    A tuple ``(jaxpr_known, consts, which_hoisted, hoisted_res)``: the re-traced
    closed jaxpr, its new leading inputs (consts produced by the trace followed
    by the ref-typed entries of ``known_consts``), one flag per residual saying
    whether it was hoisted, and the hoisted residual values.
  """
  # To disable:
  #   return jaxpr_known, known_consts, [False] * num_res, []

  const_is_ref = [isinstance(typeof(c), state.AbstractRef) for c in known_consts]
  # Refs must remain inputs of the traced jaxpr, so they are marked unknown.
  const_pvals = [
      pe.PartialVal.unknown(typeof(c)) if is_ref else pe.PartialVal.known(c)
      for c, is_ref in zip(known_consts, const_is_ref)]
  rest_avals = jaxpr_known.in_avals[len(const_pvals):]
  rest_pvals = [pe.PartialVal.unknown(aval) for aval in rest_avals]

  num_known_outs = len(jaxpr_known.out_avals) - num_res
  # Force the real outputs to be computed by the jaxpr; leave residuals free to
  # come out as known values when they do not depend on unknown inputs.
  instantiate = [True] * num_known_outs + [False] * num_res
  wrapped = lu.wrap_init(core.jaxpr_as_fun(jaxpr_known),
                         debug_info=jaxpr_known.jaxpr.debug_info)
  with source_info_util.reset_name_stack():
    traced, pvals_out, new_known_consts = pe.trace_to_jaxpr_nounits(
        wrapped, const_pvals + rest_pvals, instantiate=instantiate)
  jaxpr_known = pe.close_jaxpr(pe.convert_constvars_jaxpr(traced))

  which_hoisted = []
  hoisted_res = []
  for pval in pvals_out[num_known_outs:]:
    known = pval.is_known()
    which_hoisted.append(known)
    if known:
      hoisted_res.append(pval.get_known())

  mut_consts = [c for c, is_ref in zip(known_consts, const_is_ref) if is_ref]
  return jaxpr_known, [*new_known_consts, *mut_consts], which_hoisted, hoisted_res
831+
832+
812833
def _scan_partial_eval(trace, *tracers, reverse: bool,
813834
length: int, num_consts: int, num_carry: int,
814835
jaxpr: core.ClosedJaxpr, linear: Sequence[bool],
@@ -819,148 +840,107 @@ def _scan_partial_eval(trace, *tracers, reverse: bool,
819840

820841
# Fixpoint computation of which carry elements are unknown. Each iteration
821842
# promotes at least one carry to unknown. We need at most len(carry)
822-
# iterations, but we need one last iteration to prepare the jaxpr based on the
823-
# final carry_uk.
843+
# iterations to decide carry_uk, plus one to prepare the jaxpr.
824844
carry_uk = init_uk
845+
# Don't allow forwarding from the carry or numpy.ndarrays.
846+
fwd = [(i < num_consts or i >= num_consts + num_carry) and
847+
not isinstance(t.pval.get_known(), np.ndarray) for i, t in enumerate(tracers)]
825848
for _ in range(1 + len(carry_uk)):
826849
unknowns = const_uk + carry_uk + xs_uk
827-
jaxpr_known, jaxpr_unknown, out_uk, res_avals = pe.partial_eval_jaxpr_nounits(
828-
jaxpr, unknowns, instantiate=carry_uk + [False] * num_ys)
850+
jaxpr_known, jaxpr_unknown, out_uk, res_avals, in_fwd_res = \
851+
pe.partial_eval_jaxpr_nounits_fwd(
852+
jaxpr, unknowns, instantiate=carry_uk + [False] * num_ys, fwd=fwd)
829853
carry_uk_out, ys_uk = split_list(out_uk, [num_carry])
830854
if carry_uk_out == carry_uk:
831855
break
832856
else:
833857
carry_uk = _map(operator.or_, carry_uk, carry_uk_out)
834858
else:
835859
assert False, "Fixpoint not reached"
836-
num_res = len(res_avals)
860+
num_res_out, num_res_in = len(res_avals), len(in_fwd_res)
861+
num_knowns_out = len(jaxpr_known.out_avals) - num_res_out
862+
num_consts_known = num_consts - sum(const_uk)
863+
num_carry_known = num_carry - sum(carry_uk)
837864
del res_avals, carry_uk_out
838865

839866
# Instantiate those inputs which must be treated as unknown from the fixpoint.
840-
tracers = tuple(trace.instantiate_const(t) if uk else t
841-
for t, uk in zip(tracers, unknowns))
842-
843-
# The residual inputs and outputs of the jaxprs produced haven't yet been
844-
# adapted to the scan calling convention; in particular, jaxpr_known has its
845-
# residual outputs all at the end, meaning they're extensive outputs (which is
846-
# fully general but may be wasteful for residuals which are loop-invariant)
847-
# while jaxpr_unknown has its corresponding residual inputs at the front (just
848-
# as a convention with partial_eval_jaxpr_nounits), making them constant
849-
# inputs. To make them consistent, we move the residual inputs on
850-
# jaxpr_unknown to the end, even though we may move some back in the sequel.
867+
tracers = [trace.instantiate_const(t) if uk else t
868+
for t, uk in zip(tracers, unknowns)]
869+
known_ins = [t.pval.get_known() for t in tracers if t.pval.is_known()]
870+
unknown_ins = [t for t in tracers if not t.pval.is_known()]
871+
872+
# At this point all non-forwarded residuals are treated as extensive outputs
873+
# of jaxpr_known. Hoist out those that only depend on consts.
874+
# Before: jaxpr_known: [*known_ins] -> [*known_outs, *non_fwd_res]
875+
# After: jaxpr_known: [*known_consts, *known_ins] -> [*known_outs, *ext_res]
876+
# where, modulo hoisted res not being broadcast, we have
877+
# non_fwd_res = merge_lists(which_hoisted, ext_res, hoisted_res)
878+
known_consts, known_ins = split_list(known_ins, [num_consts_known])
879+
jaxpr_known_, known_consts_, which_hoisted, hoisted_res = \
880+
_scan_known_hoisting(jaxpr_known, known_consts, num_res_out)
881+
882+
# To make jaxpr_unknown match the scan calling convention, move to the back
883+
# binders that don't correspond to hoisted or const-forwarded residuals.
884+
# Before: jaxpr_unknown: [*res, *unknown_ins] -> [*unknown_outs]
885+
# After: jaxpr_unknown: [*int_res, *unknown_ins, *ext_res] -> [*unknown_outs]
886+
num_unk_in = len(jaxpr_unknown.in_avals) - num_res_in
887+
which_hoisted_ = iter(which_hoisted)
888+
res_to_move = [not next(which_hoisted_) if f is None else
889+
f >= len(jaxpr.consts) + num_consts_known + num_carry_known
890+
for f in in_fwd_res]
891+
assert next(which_hoisted_, None) is None
851892
jaxpr_unknown = pe.move_binders_to_back(
852-
jaxpr_unknown, [True] * num_res + [False] * sum(unknowns))
853-
854-
# At this point, all residuals are treated as extensive outputs of jaxpr_known
855-
# (and extensive inputs to jaxpr_unknown). But residuals that are loop-
856-
# invariant can be hoisted out of the scan, rather than letting them get
857-
# broadcast (as in e.g. scanning multiplication by a constant matrix; we don't
858-
# want to broadcast the matrix!). So, outside the loop we perform a partial
859-
# evaluation with known 'const' inputs (but all other inputs unknown).
860-
const_pvals = [pe.PartialVal.known(t.pval.get_known())
861-
if not isinstance(t.aval, state.AbstractRef)
862-
else pe.PartialVal.unknown(t.aval)
863-
for t in tracers[:num_consts] if t.pval.is_known()]
864-
other_pvals = [pe.PartialVal.unknown(aval)
865-
for aval in jaxpr_known.in_avals[len(const_pvals):]]
866-
with source_info_util.reset_name_stack():
867-
jaxpr_known_, invar_pvals_out, jaxpr_known_consts = pe.trace_to_jaxpr_nounits(
868-
lu.wrap_init(core.jaxpr_as_fun(jaxpr_known),
869-
debug_info=jaxpr_known.jaxpr.debug_info),
870-
const_pvals + other_pvals,
871-
instantiate=[True] * (len(out_uk) - sum(out_uk)) + [False] * num_res)
872-
jaxpr_known = pe.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr_known_), ())
873-
# The above trace_to_jaxpr_nounits call computed loop-invariant residuals
874-
# (known values in invar_pvals_out) and also computed loop-invariant values
875-
# needed by the new jaxpr_known (in jaxpr_known_consts, which replace the
876-
# previous consts). We need to collect the computed intensive residuals, and
877-
# move corresponding intensive residual binders in jaxpr_unknown to the front.
878-
res_pvals = invar_pvals_out[len(invar_pvals_out) - num_res:]
879-
intensive_res = [pval.get_known() for pval in res_pvals if pval.is_known()]
880-
jaxpr_unknown = pe.move_binders_to_front(
881-
jaxpr_unknown,
882-
[False] * sum(unknowns) + [pval.is_known() for pval in res_pvals])
883-
del const_pvals, other_pvals, invar_pvals_out, jaxpr_known_, res_pvals
884-
# We use `jaxpr_known_consts` when we call scan_p.bind with jaxpr_known, and
885-
# we use `intensive_res` when we build the jaxpr eqn with jaxpr_unknown.
886-
887-
# As another optimization, for any extensive inputs that are just forwarded to
888-
# extensive outputs, to avoid a copy (which would be looping over
889-
# dynamic-update-slice) we'd rather forward the input tracer/value. That means
890-
# pruning some outputs from jaxpr_known here, and updating `out_flat` below.
891-
fwds_known = pe._jaxpr_forwarding(jaxpr_known.jaxpr)
892-
# Prune fwds_known to include only extensive input to extensive output.
893-
fwds_known = [in_idx if out_idx >= num_carry - sum(carry_uk) and
894-
in_idx is not None and
895-
in_idx >= len(jaxpr_known_consts) + num_carry - sum(carry_uk)
896-
else None for out_idx, in_idx in enumerate(fwds_known)]
897-
# Drop any extensive output we can instead get by forwarding an input.
898-
# TODO(mattjj): use pe.dce_jaxpr here, though need a fixpoint
899-
jaxpr_known_, () = jaxpr_known.jaxpr, jaxpr_known.consts
900-
jaxpr_known_ = jaxpr_known_.replace(
901-
outvars=[x for x, i in zip(jaxpr_known_.outvars, fwds_known) if i is None])
902-
jaxpr_known = core.ClosedJaxpr(jaxpr_known_, ())
903-
del jaxpr_known_
904-
# We use `fwds_known` below when forming the output of scanning jaxpr_known.
893+
jaxpr_unknown, res_to_move + [False] * num_unk_in)
905894

906895
# Run the known part of the scan (if it has any outputs or effects).
907-
known_mutable_consts = [t.pval.get_known() for t in tracers[:num_consts]
908-
if t.pval.is_known() and isinstance(t.aval, state.AbstractRef)]
909-
known_inputs = (list(jaxpr_known_consts) + known_mutable_consts +
910-
[t.pval.get_known() for t in tracers[num_consts:]
911-
if t.pval.is_known()])
912-
if not jaxpr_known.out_avals and not jaxpr_known.effects:
913-
out_known = []
896+
linear_known, linear_unknown = partition_list(unknowns, linear)
897+
if not jaxpr_known_.out_avals and not jaxpr_known_.effects:
898+
known_outs_ext_res = []
914899
else:
915-
linear_known = [False] * len(known_inputs) # conservative!
916-
out_known = scan_p.bind(
917-
*known_inputs, reverse=reverse, length=length, jaxpr=jaxpr_known,
918-
num_consts=len(jaxpr_known_consts) + len(known_mutable_consts),
919-
num_carry=num_carry - sum(carry_uk),
920-
linear=tuple(linear_known), unroll=unroll,
900+
linear_known = [False] * len(jaxpr_known_.in_avals) # TODO conservative
901+
assert len(known_consts_) + len(known_ins) == len(jaxpr_known_.in_avals)
902+
known_outs_ext_res = scan_p.bind(
903+
*known_consts_, *known_ins, jaxpr=jaxpr_known_, reverse=reverse,
904+
length=length, num_consts=len(known_consts_),
905+
num_carry=num_carry_known, linear=(*linear_known,), unroll=unroll,
921906
_split_transpose=_split_transpose)
922-
del linear_known
923-
# Complete the known output by filling in forwarded values using fwds_known.
924-
out_known_iter = iter(out_known)
925-
out_known = [next(out_known_iter) if f is None
926-
else _maybe_put(known_inputs[f]) for f in fwds_known]
927-
assert next(out_known_iter, None) is None
928-
del known_inputs, out_known_iter
929-
930-
# Split known outputs from residuals.
931-
out_known, extensive_res = split_list(out_known, [len(out_uk) - sum(out_uk)])
932-
assert len(intensive_res) + len(extensive_res) == num_res
907+
known_outs, ext_res = split_list(known_outs_ext_res, [num_knowns_out])
908+
909+
# Complete non_fwd_res and then res, then split to match binders.
910+
non_fwd_res = merge_lists(which_hoisted, ext_res, hoisted_res)
911+
non_fwd_res_ = iter(non_fwd_res)
912+
res = [next(non_fwd_res_) if f is None
913+
else [*jaxpr.consts, *known_consts, *known_ins][f] for f in in_fwd_res]
914+
assert next(non_fwd_res_, None) is None
915+
int_res, ext_res = partition_list(res_to_move, res)
933916

934917
# Create input tracers for jaxpr_unknown bind.
935918
unknown_inputs = [t for t in tracers if not t.pval.is_known()]
936-
intensive_res = _map(trace.new_instantiated_const, intensive_res)
937-
extensive_res = _map(trace.new_instantiated_const, extensive_res)
919+
int_res = _map(trace.new_instantiated_const, int_res)
920+
ext_res = _map(trace.new_instantiated_const, ext_res)
938921
# Create output tracers for jaxpr_unknown bind, adapting extensive shapes.
939922
carry_avals, y_avals = split_list(jaxpr_unknown.out_avals, [sum(carry_uk)])
940-
ys_avals = [core.unmapped_aval(length, 0, y_aval)
941-
for y_aval in y_avals]
923+
ys_avals = [core.unmapped_aval(length, 0, y_aval) for y_aval in y_avals]
942924
out_tracers = [pe.JaxprTracer(trace, pe.PartialVal.unknown(a), None)
943925
for a in it.chain(carry_avals, ys_avals)]
944926
del carry_avals, y_avals
945927
# Create equation.
946-
linear_unknown = tuple([False] * len(intensive_res) +
947-
[l for l, uk in zip(linear, unknowns) if uk] +
948-
[False] * len(extensive_res))
928+
linear_unknown = [False] * len(int_res) + linear_unknown + [False] * len(ext_res)
929+
assert len(linear_unknown) == len(jaxpr_unknown.in_avals)
949930
name_stack = source_info_util.current_name_stack()[len(trace.name_stack):]
950931
source = source_info_util.current().replace(name_stack=name_stack)
951-
assert len(out_tracers) == len(jaxpr_unknown.out_avals)
952-
eqn = pe.new_eqn_recipe(trace, [*intensive_res, *unknown_inputs, *extensive_res],
932+
eqn = pe.new_eqn_recipe(trace, [*int_res, *unknown_inputs, *ext_res],
953933
out_tracers, scan_p,
954934
dict(reverse=reverse, length=length, unroll=unroll,
955-
jaxpr=jaxpr_unknown, linear=linear_unknown,
956-
num_consts=len(intensive_res) + sum(const_uk),
935+
jaxpr=jaxpr_unknown, linear=(*linear_unknown,),
936+
num_consts=len(int_res) + sum(const_uk),
957937
num_carry=sum(carry_uk),
958938
_split_transpose=_split_transpose),
959939
jaxpr_unknown.effects, source)
960940
for t in out_tracers: t.recipe = eqn
961941

962942
# Merge known and unknown outputs into final result.
963-
return util.merge_lists(out_uk, out_known, out_tracers)
943+
return util.merge_lists(out_uk, known_outs, out_tracers)
964944

965945
def _maybe_put(x):
966946
if isinstance(x, np.ndarray):

tests/lax_control_flow_test.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import contextlib
1818
from functools import partial
1919
import itertools
20+
import math
2021
import operator
2122
import re
2223
import unittest
@@ -3370,6 +3371,90 @@ def g(x):
33703371
jaxpr = jax.make_jaxpr(g)(jnp.arange(3.))
33713372
self.assertLen(jaxpr.eqns[0].params["jaxpr"].jaxpr.outvars, 1)
33723373

3374+
@parameterized.parameters(itertools.product([0, 4], repeat=5))
@jtu.run_on_devices("cpu")
def test_scan_vjp_forwarding_correctness(
    self,
    seed,
    num_rule_consts,
    num_const_fwds,
    num_carry_fwds,
    num_input_fwds):
  # Unlike test_scan_forwarding_correctness, which tests forwarding in the
  # scan traceable, this test covers forwarding logic related to residuals in
  # the scan partial eval / vjp rule. So 'forwards' refer to residuals that
  # will be forwarded.

  # We use a custom_jvp where the jvp rule introduces consts to populate
  # jaxpr.consts in _scan_partial_eval's input.
  @jax.custom_jvp
  def foo(x):
    return 3. * x
  @foo.defjvp
  def foo_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    if num_rule_consts:
      coeff = sum([jnp.array(np.ones(3) / num_rule_consts) for _ in range(num_rule_consts)])
    else:
      coeff = 1.
    return foo(x), jnp.prod(coeff) * x_dot

  num_const = num_const_fwds + 2
  num_carry = num_carry_fwds + 4
  num_xs = num_input_fwds + 2
  num_ys = num_xs + 1

  rng = np.random.RandomState(seed)
  carry_perm = rng.permutation(num_carry)
  carry_iperm = np.argsort(carry_perm)

  xs_perm = rng.permutation(num_xs)
  ys_perm = rng.permutation(num_ys)
  f = np.arange(num_xs)
  f = [f[i] if idx < num_input_fwds else None for idx, i in enumerate(xs_perm)]
  f += [None]
  in_fwd = [f[i] for i in ys_perm]

  body_consts = [jnp.array(rng.randn(3)) for _ in range(num_const)]
  init_vals = list(map(jnp.array, rng.uniform(size=(num_carry, 3))))

  def body_fun(c, x):
    c = [c[i] for i in carry_iperm]

    const_fwds, const_dont_fwd = split_list(body_consts, [num_const_fwds])
    z = sum(const_dont_fwd)

    # Split the carry by the number of *carry* forwards (not const forwards),
    # so that num_carry_fwds actually controls how many carries take this path.
    carry_fwds, carry_dont_fwd = split_list(c, [num_carry_fwds])
    carry_fwds = [math.prod([x, x, *const_fwds, z]) for x in carry_fwds]
    carry_dont_fwd = [jnp.sin(x) * sum(jnp.sum(c) for c in body_consts)
                      for x in carry_dont_fwd]
    new_c_perm = [*carry_fwds, *carry_dont_fwd]
    new_c = [new_c_perm[i] for i in carry_perm]
    new_c = [foo(new_c[0]), *new_c[1:]]

    x = [x[i] for i in xs_perm]
    x_fwd, x_dont_fwd = split_list(x, [num_input_fwds])
    x_fwd = [x * x for x in x_fwd]
    x_dont_fwd = [jnp.cos(x) * sum(jnp.sum(c) for c in body_consts)
                  for x in x_dont_fwd]
    y = [*x_fwd, *x_dont_fwd, 0]
    y = [y[i] for i in ys_perm]

    return new_c, y

  xs = list(map(jnp.array, rng.uniform(size=(num_xs, 2))))

  (final, outs), vjp = jax.vjp(partial(jax.lax.scan, body_fun), init_vals, xs)
  init_vals_bar, xs_bar = vjp((final, outs))

  # Reference run without jit; the same cotangents are fed to both vjps so
  # the input cotangents are directly comparable.
  with jax.disable_jit():
    (final_ref, outs_ref), vjp = jax.vjp(partial(jax.lax.scan, body_fun), init_vals, xs)
    init_vals_bar_ref, xs_bar_ref = vjp((final, outs))

  self.assertAllClose(final, final_ref, check_dtypes=False)
  self.assertAllClose(outs, outs_ref, check_dtypes=False)
  # Check cotangents for both the carry inputs and the scanned inputs.
  self.assertAllClose(init_vals_bar, init_vals_bar_ref, check_dtypes=False)
  self.assertAllClose(xs_bar, xs_bar_ref, check_dtypes=False)
33733458

33743459
if __name__ == '__main__':
33753460
absltest.main(testLoader=jtu.JaxTestLoader())

0 commit comments

Comments
 (0)