Skip to content

Commit cfb382e

Browse files
committed
tweak
1 parent b72bb88 commit cfb382e

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

jax/_src/lax/control_flow/loops.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -872,11 +872,11 @@ def _scan_partial_eval(trace, *tracers, reverse: bool,
872872
# At this point all non-forwarded residuals are treated as extensive outputs
873873
# of jaxpr_known. Hoist out those that only depend on consts.
874874
# Before: jaxpr_known: [*known_ins] -> [*known_outs, *non_fwd_res]
875-
# After: jaxpr_known: [*known_consts, *known_ins] -> [*known_outs, *ext_res]
875+
# After: jaxpr_known: [*known_consts_, *known_ins] -> [*known_outs, *ext_res]
876876
# where, modulo hoisted res not being broadcast, we have
877877
# non_fwd_res = merge_lists(which_hoisted, ext_res, hoisted_res)
878878
known_consts, known_ins = split_list(known_ins, [num_consts_known])
879-
jaxpr_known_, known_consts_, which_hoisted, hoisted_res = \
879+
jaxpr_known, known_consts_, which_hoisted, hoisted_res = \
880880
_scan_known_hoisting(jaxpr_known, known_consts, num_res_out)
881881

882882
# To make jaxpr_unknown match the scan calling convention, move to the back
@@ -894,13 +894,13 @@ def _scan_partial_eval(trace, *tracers, reverse: bool,
894894

895895
# Run the known part of the scan (if it has any outputs or effects).
896896
linear_known, linear_unknown = partition_list(unknowns, linear)
897-
if not jaxpr_known_.out_avals and not jaxpr_known_.effects:
897+
if not jaxpr_known.out_avals and not jaxpr_known.effects:
898898
known_outs_ext_res = []
899899
else:
900-
linear_known = [False] * len(jaxpr_known_.in_avals) # TODO conservative
901-
assert len(known_consts_) + len(known_ins) == len(jaxpr_known_.in_avals)
900+
linear_known = [False] * len(jaxpr_known.in_avals) # TODO conservative
901+
assert len(known_consts_) + len(known_ins) == len(jaxpr_known.in_avals)
902902
known_outs_ext_res = scan_p.bind(
903-
*known_consts_, *known_ins, jaxpr=jaxpr_known_, reverse=reverse,
903+
*known_consts_, *known_ins, jaxpr=jaxpr_known, reverse=reverse,
904904
length=length, num_consts=len(known_consts_),
905905
num_carry=num_carry_known, linear=(*linear_known,), unroll=unroll,
906906
_split_transpose=_split_transpose)

0 commit comments

Comments
 (0)