@@ -872,11 +872,11 @@ def _scan_partial_eval(trace, *tracers, reverse: bool,
872
872
# At this point all non-forwarded residuals are treated as extensive outputs
873
873
# of jaxpr_known. Hoist out those that only depend on consts.
874
874
# Before: jaxpr_known: [*known_ins] -> [*known_outs, *non_fwd_res]
875
- # After: jaxpr_known: [*known_consts , *known_ins] -> [*known_outs, *ext_res]
875
+ # After: jaxpr_known: [*known_consts_ , *known_ins] -> [*known_outs, *ext_res]
876
876
# where, modulo hoisted res not being broadcast, we have
877
877
# non_fwd_res = merge_lists(which_hoisted, ext_res, hoisted_res)
878
878
known_consts , known_ins = split_list (known_ins , [num_consts_known ])
879
- jaxpr_known_ , known_consts_ , which_hoisted , hoisted_res = \
879
+ jaxpr_known , known_consts_ , which_hoisted , hoisted_res = \
880
880
_scan_known_hoisting (jaxpr_known , known_consts , num_res_out )
881
881
882
882
# To make jaxpr_unknown match the scan calling convention, move to the back
@@ -894,13 +894,13 @@ def _scan_partial_eval(trace, *tracers, reverse: bool,
894
894
895
895
# Run the known part of the scan (if it has any outputs or effects).
896
896
linear_known , linear_unknown = partition_list (unknowns , linear )
897
- if not jaxpr_known_ .out_avals and not jaxpr_known_ .effects :
897
+ if not jaxpr_known .out_avals and not jaxpr_known .effects :
898
898
known_outs_ext_res = []
899
899
else :
900
- linear_known = [False ] * len (jaxpr_known_ .in_avals ) # TODO conservative
901
- assert len (known_consts_ ) + len (known_ins ) == len (jaxpr_known_ .in_avals )
900
+ linear_known = [False ] * len (jaxpr_known .in_avals ) # TODO conservative
901
+ assert len (known_consts_ ) + len (known_ins ) == len (jaxpr_known .in_avals )
902
902
known_outs_ext_res = scan_p .bind (
903
- * known_consts_ , * known_ins , jaxpr = jaxpr_known_ , reverse = reverse ,
903
+ * known_consts_ , * known_ins , jaxpr = jaxpr_known , reverse = reverse ,
904
904
length = length , num_consts = len (known_consts_ ),
905
905
num_carry = num_carry_known , linear = (* linear_known ,), unroll = unroll ,
906
906
_split_transpose = _split_transpose )
0 commit comments