37
37
from jax ._src import util
38
38
from jax ._src .api_util import (
39
39
_check_no_aliased_ref_args , _check_no_aliased_closed_over_refs )
40
- from jax ._src .core import ShapedArray
40
+ from jax ._src .core import ShapedArray , typeof
41
41
from jax ._src .interpreters import ad
42
42
from jax ._src .interpreters import batching
43
43
from jax ._src .interpreters import mlir
@@ -809,6 +809,27 @@ def _const_to_intensive_res_forwarding(
809
809
tangent_jaxpr , [False ] * num_nz + [i is not None for i in const_to_res ])
810
810
return primal_jaxpr , tangent_jaxpr , intensive_res , new_in_fwd
811
811
812
+ def _scan_known_hoisting (jaxpr_known , known_consts , num_res ):
813
+ # To disable:
814
+ # return jaxpr_known, known_consts, [False] * num_res, []
815
+
816
+ consts = [pe .PartialVal .unknown (a ) if isinstance (a := typeof (c ), state .AbstractRef )
817
+ else pe .PartialVal .known (c ) for c in known_consts ]
818
+ others = _map (pe .PartialVal .unknown , jaxpr_known .in_avals [len (consts ):])
819
+ num_known_outs = len (jaxpr_known .out_avals ) - num_res
820
+ with source_info_util .reset_name_stack ():
821
+ jaxpr_known_ , pvals_out , new_known_consts = pe .trace_to_jaxpr_nounits (
822
+ lu .wrap_init (core .jaxpr_as_fun (jaxpr_known ),
823
+ debug_info = jaxpr_known .jaxpr .debug_info ),
824
+ consts + others , instantiate = [True ] * num_known_outs + [False ] * num_res )
825
+ jaxpr_known = pe .close_jaxpr (pe .convert_constvars_jaxpr (jaxpr_known_ ))
826
+ res_pvals = pvals_out [num_known_outs :]
827
+ which_hoisted = [pval .is_known () for pval in res_pvals ]
828
+ hoisted_res = [pval .get_known () for pval in res_pvals if pval .is_known ()]
829
+ mut_consts = [c for c in known_consts if isinstance (typeof (c ), state .AbstractRef )]
830
+ return jaxpr_known , [* new_known_consts , * mut_consts ], which_hoisted , hoisted_res
831
+
832
+
812
833
def _scan_partial_eval (trace , * tracers , reverse : bool ,
813
834
length : int , num_consts : int , num_carry : int ,
814
835
jaxpr : core .ClosedJaxpr , linear : Sequence [bool ],
@@ -819,148 +840,107 @@ def _scan_partial_eval(trace, *tracers, reverse: bool,
819
840
820
841
# Fixpoint computation of which carry elements are unknown. Each iteration
821
842
# promotes at least one carry to unknown. We need at most len(carry)
822
- # iterations, but we need one last iteration to prepare the jaxpr based on the
823
- # final carry_uk.
843
+ # iterations to decide carry_uk, plus one to prepare the jaxpr.
824
844
carry_uk = init_uk
845
+ # Don't allow forwarding from the carry or numpy.ndarrays.
846
+ fwd = [(i < num_consts or i >= num_consts + num_carry ) and
847
+ not isinstance (t .pval .get_known (), np .ndarray ) for i , t in enumerate (tracers )]
825
848
for _ in range (1 + len (carry_uk )):
826
849
unknowns = const_uk + carry_uk + xs_uk
827
- jaxpr_known , jaxpr_unknown , out_uk , res_avals = pe .partial_eval_jaxpr_nounits (
828
- jaxpr , unknowns , instantiate = carry_uk + [False ] * num_ys )
850
+ jaxpr_known , jaxpr_unknown , out_uk , res_avals , in_fwd_res = \
851
+ pe .partial_eval_jaxpr_nounits_fwd (
852
+ jaxpr , unknowns , instantiate = carry_uk + [False ] * num_ys , fwd = fwd )
829
853
carry_uk_out , ys_uk = split_list (out_uk , [num_carry ])
830
854
if carry_uk_out == carry_uk :
831
855
break
832
856
else :
833
857
carry_uk = _map (operator .or_ , carry_uk , carry_uk_out )
834
858
else :
835
859
assert False , "Fixpoint not reached"
836
- num_res = len (res_avals )
860
+ num_res_out , num_res_in = len (res_avals ), len (in_fwd_res )
861
+ num_knowns_out = len (jaxpr_known .out_avals ) - num_res_out
862
+ num_consts_known = num_consts - sum (const_uk )
863
+ num_carry_known = num_carry - sum (carry_uk )
837
864
del res_avals , carry_uk_out
838
865
839
866
# Instantiate those inputs which must be treated as unknown from the fixpoint.
840
- tracers = tuple (trace .instantiate_const (t ) if uk else t
841
- for t , uk in zip (tracers , unknowns ))
842
-
843
- # The residual inputs and outputs of the jaxprs produced haven't yet been
844
- # adapted to the scan calling convention; in particular, jaxpr_known has its
845
- # residual outputs all at the end, meaning they're extensive outputs (which is
846
- # fully general but may be wasteful for residuals which are loop-invariant)
847
- # while jaxpr_unknown has its corresponding residual inputs at the front (just
848
- # as a convention with partial_eval_jaxpr_nounits), making them constant
849
- # inputs. To make them consistent, we move the residual inputs on
850
- # jaxpr_unknown to the end, even though we may move some back in the sequel.
867
+ tracers = [trace .instantiate_const (t ) if uk else t
868
+ for t , uk in zip (tracers , unknowns )]
869
+ known_ins = [t .pval .get_known () for t in tracers if t .pval .is_known ()]
870
+ unknown_ins = [t for t in tracers if not t .pval .is_known ()]
871
+
872
+ # At this point all non-forwarded residuals are treated as extensive outputs
873
+ # of jaxpr_known. Hoist out those that only depend on consts.
874
+ # Before: jaxpr_known: [*known_ins] -> [*known_outs, *non_fwd_res]
875
+ # After: jaxpr_known: [*known_consts, *known_ins] -> [*known_outs, *ext_res]
876
+ # where, modulo hoisted res not being broadcast, we have
877
+ # non_fwd_res = merge_lists(which_hoisted, ext_res, hoisted_res)
878
+ known_consts , known_ins = split_list (known_ins , [num_consts_known ])
879
+ jaxpr_known_ , known_consts_ , which_hoisted , hoisted_res = \
880
+ _scan_known_hoisting (jaxpr_known , known_consts , num_res_out )
881
+
882
+ # To make jaxpr_unknown match the scan calling convention, move to the back
883
+ # binders that don't correspond to hoisted or const-forwarded residuals.
884
+ # Before: jaxpr_unknown: [*res, *unknown_ins] -> [*unknown_outs]
885
+ # After: jaxpr_unknown: [*int_res, *unknown_ins, *ext_res] -> [*unknown_outs]
886
+ num_unk_in = len (jaxpr_unknown .in_avals ) - num_res_in
887
+ which_hoisted_ = iter (which_hoisted )
888
+ res_to_move = [not next (which_hoisted_ ) if f is None else
889
+ f >= len (jaxpr .consts ) + num_consts_known + num_carry_known
890
+ for f in in_fwd_res ]
891
+ assert next (which_hoisted_ , None ) is None
851
892
jaxpr_unknown = pe .move_binders_to_back (
852
- jaxpr_unknown , [True ] * num_res + [False ] * sum (unknowns ))
853
-
854
- # At this point, all residuals are treated as extensive outputs of jaxpr_known
855
- # (and extensive inputs to jaxpr_unknown). But residuals that are loop-
856
- # invariant can be hoisted out of the scan, rather than letting them get
857
- # broadcast (as in e.g. scanning multiplication by a constant matrix; we don't
858
- # want to broadcast the matrix!). So, outside the loop we perform a partial
859
- # evaluation with known 'const' inputs (but all other inputs unknown).
860
- const_pvals = [pe .PartialVal .known (t .pval .get_known ())
861
- if not isinstance (t .aval , state .AbstractRef )
862
- else pe .PartialVal .unknown (t .aval )
863
- for t in tracers [:num_consts ] if t .pval .is_known ()]
864
- other_pvals = [pe .PartialVal .unknown (aval )
865
- for aval in jaxpr_known .in_avals [len (const_pvals ):]]
866
- with source_info_util .reset_name_stack ():
867
- jaxpr_known_ , invar_pvals_out , jaxpr_known_consts = pe .trace_to_jaxpr_nounits (
868
- lu .wrap_init (core .jaxpr_as_fun (jaxpr_known ),
869
- debug_info = jaxpr_known .jaxpr .debug_info ),
870
- const_pvals + other_pvals ,
871
- instantiate = [True ] * (len (out_uk ) - sum (out_uk )) + [False ] * num_res )
872
- jaxpr_known = pe .ClosedJaxpr (pe .convert_constvars_jaxpr (jaxpr_known_ ), ())
873
- # The above trace_to_jaxpr_nounits call computed loop-invariant residuals
874
- # (known values in invar_pvals_out) and also computed loop-invariant values
875
- # needed by the new jaxpr_known (in jaxpr_known_consts, which replace the
876
- # previous consts). We need to collect the computed intensive residuals, and
877
- # move corresponding intensive residual binders in jaxpr_unknown to the front.
878
- res_pvals = invar_pvals_out [len (invar_pvals_out ) - num_res :]
879
- intensive_res = [pval .get_known () for pval in res_pvals if pval .is_known ()]
880
- jaxpr_unknown = pe .move_binders_to_front (
881
- jaxpr_unknown ,
882
- [False ] * sum (unknowns ) + [pval .is_known () for pval in res_pvals ])
883
- del const_pvals , other_pvals , invar_pvals_out , jaxpr_known_ , res_pvals
884
- # We use `jaxpr_known_consts` when we call scan_p.bind with jaxpr_known, and
885
- # we use `intensive_res` when we build the jaxpr eqn with jaxpr_unknown.
886
-
887
- # As another optimization, for any extensive inputs that are just forwarded to
888
- # extensive outputs, to avoid a copy (which would be looping over
889
- # dynamic-update-slice) we'd rather forward the input tracer/value. That means
890
- # pruning some outputs from jaxpr_known here, and updating `out_flat` below.
891
- fwds_known = pe ._jaxpr_forwarding (jaxpr_known .jaxpr )
892
- # Prune fwds_known to include only extensive input to extensive output.
893
- fwds_known = [in_idx if out_idx >= num_carry - sum (carry_uk ) and
894
- in_idx is not None and
895
- in_idx >= len (jaxpr_known_consts ) + num_carry - sum (carry_uk )
896
- else None for out_idx , in_idx in enumerate (fwds_known )]
897
- # Drop any extensive output we can instead get by forwarding an input.
898
- # TODO(mattjj): use pe.dce_jaxpr here, though need a fixpoint
899
- jaxpr_known_ , () = jaxpr_known .jaxpr , jaxpr_known .consts
900
- jaxpr_known_ = jaxpr_known_ .replace (
901
- outvars = [x for x , i in zip (jaxpr_known_ .outvars , fwds_known ) if i is None ])
902
- jaxpr_known = core .ClosedJaxpr (jaxpr_known_ , ())
903
- del jaxpr_known_
904
- # We use `fwds_known` below when forming the output of scanning jaxpr_known.
893
+ jaxpr_unknown , res_to_move + [False ] * num_unk_in )
905
894
906
895
# Run the known part of the scan (if it has any outputs or effects).
907
- known_mutable_consts = [t .pval .get_known () for t in tracers [:num_consts ]
908
- if t .pval .is_known () and isinstance (t .aval , state .AbstractRef )]
909
- known_inputs = (list (jaxpr_known_consts ) + known_mutable_consts +
910
- [t .pval .get_known () for t in tracers [num_consts :]
911
- if t .pval .is_known ()])
912
- if not jaxpr_known .out_avals and not jaxpr_known .effects :
913
- out_known = []
896
+ linear_known , linear_unknown = partition_list (unknowns , linear )
897
+ if not jaxpr_known_ .out_avals and not jaxpr_known_ .effects :
898
+ known_outs_ext_res = []
914
899
else :
915
- linear_known = [False ] * len (known_inputs ) # conservative!
916
- out_known = scan_p . bind (
917
- * known_inputs , reverse = reverse , length = length , jaxpr = jaxpr_known ,
918
- num_consts = len ( jaxpr_known_consts ) + len ( known_mutable_consts ) ,
919
- num_carry = num_carry - sum ( carry_uk ),
920
- linear = tuple ( linear_known ), unroll = unroll ,
900
+ linear_known = [False ] * len (jaxpr_known_ . in_avals ) # TODO conservative
901
+ assert len ( known_consts_ ) + len ( known_ins ) == len ( jaxpr_known_ . in_avals )
902
+ known_outs_ext_res = scan_p . bind (
903
+ * known_consts_ , * known_ins , jaxpr = jaxpr_known_ , reverse = reverse ,
904
+ length = length , num_consts = len ( known_consts_ ),
905
+ num_carry = num_carry_known , linear = ( * linear_known , ), unroll = unroll ,
921
906
_split_transpose = _split_transpose )
922
- del linear_known
923
- # Complete the known output by filling in forwarded values using fwds_known.
924
- out_known_iter = iter (out_known )
925
- out_known = [next (out_known_iter ) if f is None
926
- else _maybe_put (known_inputs [f ]) for f in fwds_known ]
927
- assert next (out_known_iter , None ) is None
928
- del known_inputs , out_known_iter
929
-
930
- # Split known outputs from residuals.
931
- out_known , extensive_res = split_list (out_known , [len (out_uk ) - sum (out_uk )])
932
- assert len (intensive_res ) + len (extensive_res ) == num_res
907
+ known_outs , ext_res = split_list (known_outs_ext_res , [num_knowns_out ])
908
+
909
+ # Complete non_fwd_res and then res, then split to match binders.
910
+ non_fwd_res = merge_lists (which_hoisted , ext_res , hoisted_res )
911
+ non_fwd_res_ = iter (non_fwd_res )
912
+ res = [next (non_fwd_res_ ) if f is None
913
+ else [* jaxpr .consts , * known_consts , * known_ins ][f ] for f in in_fwd_res ]
914
+ assert next (non_fwd_res_ , None ) is None
915
+ int_res , ext_res = partition_list (res_to_move , res )
933
916
934
917
# Create input tracers for jaxpr_unknown bind.
935
918
unknown_inputs = [t for t in tracers if not t .pval .is_known ()]
936
- intensive_res = _map (trace .new_instantiated_const , intensive_res )
937
- extensive_res = _map (trace .new_instantiated_const , extensive_res )
919
+ int_res = _map (trace .new_instantiated_const , int_res )
920
+ ext_res = _map (trace .new_instantiated_const , ext_res )
938
921
# Create output tracers for jaxpr_unknown bind, adapting extensive shapes.
939
922
carry_avals , y_avals = split_list (jaxpr_unknown .out_avals , [sum (carry_uk )])
940
- ys_avals = [core .unmapped_aval (length , 0 , y_aval )
941
- for y_aval in y_avals ]
923
+ ys_avals = [core .unmapped_aval (length , 0 , y_aval ) for y_aval in y_avals ]
942
924
out_tracers = [pe .JaxprTracer (trace , pe .PartialVal .unknown (a ), None )
943
925
for a in it .chain (carry_avals , ys_avals )]
944
926
del carry_avals , y_avals
945
927
# Create equation.
946
- linear_unknown = tuple ([False ] * len (intensive_res ) +
947
- [l for l , uk in zip (linear , unknowns ) if uk ] +
948
- [False ] * len (extensive_res ))
928
+ linear_unknown = [False ] * len (int_res ) + linear_unknown + [False ] * len (ext_res )
929
+ assert len (linear_unknown ) == len (jaxpr_unknown .in_avals )
949
930
name_stack = source_info_util .current_name_stack ()[len (trace .name_stack ):]
950
931
source = source_info_util .current ().replace (name_stack = name_stack )
951
- assert len (out_tracers ) == len (jaxpr_unknown .out_avals )
952
- eqn = pe .new_eqn_recipe (trace , [* intensive_res , * unknown_inputs , * extensive_res ],
932
+ eqn = pe .new_eqn_recipe (trace , [* int_res , * unknown_inputs , * ext_res ],
953
933
out_tracers , scan_p ,
954
934
dict (reverse = reverse , length = length , unroll = unroll ,
955
- jaxpr = jaxpr_unknown , linear = linear_unknown ,
956
- num_consts = len (intensive_res ) + sum (const_uk ),
935
+ jaxpr = jaxpr_unknown , linear = ( * linear_unknown ,) ,
936
+ num_consts = len (int_res ) + sum (const_uk ),
957
937
num_carry = sum (carry_uk ),
958
938
_split_transpose = _split_transpose ),
959
939
jaxpr_unknown .effects , source )
960
940
for t in out_tracers : t .recipe = eqn
961
941
962
942
# Merge known and unknown outputs into final result.
963
- return util .merge_lists (out_uk , out_known , out_tracers )
943
+ return util .merge_lists (out_uk , known_outs , out_tracers )
964
944
965
945
def _maybe_put (x ):
966
946
if isinstance (x , np .ndarray ):
0 commit comments