Skip to content

Commit a439253

Browse files
committed
Try to add constants directly into core.Jaxpr
DO_NOT_SUBMIT
1 parent e818940 commit a439253

File tree

9 files changed

+55
-40
lines changed

9 files changed

+55
-40
lines changed

jax/_src/api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2425,7 +2425,7 @@ def make_jaxpr_f(*args, **kwargs):
24252425
if traced._num_consts:
24262426
consts, _ = split_list(traced._args_flat, [traced._num_consts])
24272427
jaxpr_ = pe.convert_invars_to_constvars(traced.jaxpr.jaxpr,
2428-
traced._num_consts)
2428+
traced._num_consts, consts)
24292429
jaxpr = core.ClosedJaxpr(jaxpr_, consts)
24302430
else:
24312431
jaxpr = traced.jaxpr

jax/_src/core.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@
8888

8989
class Jaxpr:
9090
__slots__ = ['__weakref__', '_constvars', '_invars', '_outvars', '_eqns',
91-
'_effects', '_debug_info', '_is_high']
91+
'_effects', '_debug_info', '_is_high', '_consts']
9292

9393
_constvars: list[Var]
9494
_invars: list[Var]
@@ -126,14 +126,19 @@ def debug_info(self) -> DebugInfo:
126126
def is_high(self) -> bool:
127127
return self._is_high
128128

129+
@property
130+
def consts(self) -> list[Any]:
131+
return self._consts
132+
129133
def __init__(self, constvars: Sequence[Var], invars: Sequence[Var],
130134
outvars: Sequence[Atom], eqns: Sequence[JaxprEqn],
131135
effects: Effects = no_effects,
132136
# We want all calls to pass a DebugInfo object, but for backwards
133137
# compatibility we have to allow calls when the debug_info
134138
# is missing.
135139
debug_info: DebugInfo = None, # type: ignore[annotation-type-mismatch,assignment]
136-
is_high: bool = False,
140+
is_high: bool = False, *,
141+
consts: Sequence[Any] = (),
137142
):
138143
"""
139144
Args:
@@ -146,6 +151,7 @@ def __init__(self, constvars: Sequence[Var], invars: Sequence[Var],
146151
effects: set of effects. The effects on a jaxpr are a superset of the
147152
union of the effects for each equation.
148153
debug_info: debugging information.
154+
consts: the constant values corresponding to the constvars
149155
"""
150156
self._constvars = list(constvars)
151157
self._invars = list(invars)
@@ -159,7 +165,10 @@ def __init__(self, constvars: Sequence[Var], invars: Sequence[Var],
159165
# assert (len(debug_info.arg_names) == len(invars)), (debug_info, invars)
160166
# assert (len(debug_info.result_paths) == len(outvars)), (debug_info, outvars)
161167
self._is_high = is_high
162-
num_vars = len(constvars) + len(invars)
168+
self._consts = list(consts)
169+
if len(constvars) != len(consts):
170+
assert False, (constvars, consts) # DO_NOT_SUBMIT
171+
assert len(constvars) == len(consts), (constvars, consts)
163172

164173
def __str__(self):
165174
return str(self.pretty_print())
@@ -187,6 +196,7 @@ def replace(self, **kwargs):
187196
effects=kwargs.pop("effects", self.effects),
188197
debug_info=kwargs.pop("debug_info", self.debug_info),
189198
is_high=kwargs.pop("is_high", self.is_high),
199+
consts=kwargs.pop("consts", self.consts),
190200
)
191201
if kwargs:
192202
raise ValueError(f"Unknown keyword arguments: {kwargs}")
@@ -223,11 +233,14 @@ class ClosedJaxpr:
223233
jaxpr = property(lambda self: self._jaxpr)
224234
consts = property(lambda self: self._consts)
225235

226-
def __init__(self, jaxpr: Jaxpr, consts: Sequence):
236+
def __init__(self, jaxpr: Jaxpr, consts: Sequence[Value]):
227237
assert len(consts) == len(jaxpr.constvars)
228238
# assert not any(isinstance(c, Tracer) for c in consts) # TODO(mattjj): enable
229239
self._jaxpr = jaxpr
230240
self._consts = list(consts)
241+
if jaxpr.consts != self._consts: # DO_NOT_SUBMIT
242+
assert False, (jaxpr.consts, self._consts) # TODO(necula): remove, when we remove ClosedJaxpr
243+
assert jaxpr.consts == self._consts, (jaxpr.consts, self._consts) # TODO(necula): remove, when we remove ClosedJaxpr
231244

232245
@property
233246
def in_avals(self):

jax/_src/interpreters/ad.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1224,7 +1224,7 @@ def rearrange_binders(jaxpr: core.ClosedJaxpr, primals_in, tangents_in, primals_
12241224
(*constvars, *new_invars), (*constvars, *jaxpr.jaxpr.invars),
12251225
jaxpr.jaxpr.effects)
12261226
new_jaxpr = core.Jaxpr(constvars, new_invars, new_outvars, jaxpr.jaxpr.eqns,
1227-
new_effects, new_debug_info)
1227+
new_effects, new_debug_info, consts=jaxpr.consts)
12281228
return core.ClosedJaxpr(new_jaxpr, jaxpr.consts)
12291229

12301230
def _perm(primal_counts: Sequence[int], tangent_counts: Sequence[int],

jax/_src/interpreters/partial_eval.py

Lines changed: 21 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -880,7 +880,7 @@ def sort_key(t):
880880
jaxpr_effects = make_jaxpr_effects(const_vars, invars, outvars, eqns)
881881
jaxpr = Jaxpr(const_vars, invars, # type: ignore[arg-type]
882882
outvars, eqns, jaxpr_effects,
883-
debug_info)
883+
debug_info, consts=const_vals)
884884
config.enable_checks.value and core.check_jaxpr(jaxpr)
885885
# del getvar # needed to avoid cyclic-reference closure, apparently!
886886
return jaxpr, const_vals, env_vals
@@ -897,35 +897,27 @@ def convert_constvars_jaxpr(jaxpr: Jaxpr) -> Jaxpr:
897897
dbg = jaxpr.debug_info._replace(
898898
arg_names=("",) * len(jaxpr.constvars) + jaxpr.debug_info.arg_names)
899899
lifted_jaxpr = jaxpr.replace(
900-
constvars=(), invars=jaxpr.constvars + jaxpr.invars, debug_info=dbg)
900+
constvars=(), invars=jaxpr.constvars + jaxpr.invars, debug_info=dbg, consts=())
901901
config.enable_checks.value and core.check_jaxpr(lifted_jaxpr)
902902
return lifted_jaxpr
903903

904-
@weakref_lru_cache
905-
def convert_invars_to_constvars(jaxpr: Jaxpr, n: int) -> Jaxpr:
906-
"""Move n invars to constvars. Like an inverse of convert_constvars_jaxpr."""
904+
# DO_NOT_SUBMIT
905+
# @weakref_lru_cache
906+
def convert_invars_to_constvars(jaxpr: Jaxpr, n: int, consts: Sequence[core.Value]) -> Jaxpr:
907+
"""Move the first n invars to constvars. Like an inverse of convert_constvars_jaxpr."""
907908
if n == 0:
908909
return jaxpr.replace() # 'return jaxpr' would create cache reference cycle
909910
config.enable_checks.value and core.check_jaxpr(jaxpr)
911+
assert n == len(consts)
912+
assert 0 == len(jaxpr.constvars)
910913
constvars, invars = split_list(jaxpr.invars, [n])
911914
dbg = jaxpr.debug_info._replace(
912915
arg_names=jaxpr.debug_info.arg_names[n:])
913916
lifted_jaxpr = jaxpr.replace(constvars=tuple(constvars), invars=invars,
914-
debug_info=dbg)
917+
debug_info=dbg, consts=consts)
915918
config.enable_checks.value and core.check_jaxpr(lifted_jaxpr)
916919
return lifted_jaxpr
917920

918-
def convert_envvars_to_constvars(jaxpr: Jaxpr, num_env_vars: int) -> Jaxpr:
919-
if any(isinstance(eff, effects.JaxprInputEffect) for eff in jaxpr.effects):
920-
raise NotImplementedError
921-
config.enable_checks.value and core.check_jaxpr(jaxpr)
922-
env_vars, invars = split_list(jaxpr.invars, [num_env_vars])
923-
converted_jaxpr = jaxpr.replace(constvars=jaxpr.constvars + env_vars,
924-
invars=invars)
925-
config.enable_checks.value and core.check_jaxpr(converted_jaxpr)
926-
return converted_jaxpr
927-
928-
929921
def partial_eval_jaxpr_nounits(
930922
jaxpr: ClosedJaxpr, unknowns: Sequence[bool],
931923
instantiate: bool | Sequence[bool],
@@ -1459,7 +1451,9 @@ def dce_jaxpr_consts(jaxpr: Jaxpr, used_outputs: Sequence[bool],
14591451
new_jaxpr, used_inputs_ = dce_jaxpr(jaxpr_, used_outputs, instantiate)
14601452
used_consts, used_inputs = split_list(used_inputs_, [len(jaxpr.constvars)])
14611453
if sum(used_consts):
1462-
new_jaxpr = convert_invars_to_constvars(new_jaxpr, sum(used_consts))
1454+
new_consts = tuple(c for u, c in zip(used_consts, jaxpr.consts) if u)
1455+
new_jaxpr = convert_invars_to_constvars(new_jaxpr, sum(used_consts),
1456+
new_consts)
14631457
return new_jaxpr, used_consts, used_inputs
14641458

14651459

@@ -1805,13 +1799,14 @@ def to_jaxpr(
18051799
v.final_qdd = qdd.cur_val
18061800

18071801
jaxpr = Jaxpr(constvars, invars, outvars, self.eqns, jaxpr_effects,
1808-
debug_info, self.is_high)
1802+
debug_info, self.is_high, consts=constvals)
18091803
jaxpr, constvals = _drop_unused_vars(jaxpr, constvals)
18101804
init_trees = [tree_structure(init_val) for init_val in self.attrs_inits]
18111805
return jaxpr, list(constvals), zip(init_trees, end_trees, self.attrs_tracked)
18121806

18131807
def to_jaxpr2(self, out_tracers: Sequence[core.Tracer],
1814-
debug_info: core.DebugInfo):
1808+
debug_info: core.DebugInfo
1809+
) -> tuple[core.Jaxpr, tuple[core.AbstractValue, ...]]:
18151810
# It's not necessary, but we keep the tracer-to-var mapping injective:
18161811
vars = [v for v in self.tracer_to_var.values() if not isinstance(v, Literal)]
18171812
assert len(vars) == len(set(vars))
@@ -1820,7 +1815,7 @@ def to_jaxpr2(self, out_tracers: Sequence[core.Tracer],
18201815
jaxpr_effects = make_jaxpr_effects(constvars, self.invars, expl_outvars,
18211816
self.eqns)
18221817
jaxpr = Jaxpr(constvars, self.invars, expl_outvars, self.eqns,
1823-
jaxpr_effects, debug_info)
1818+
jaxpr_effects, debug_info, consts=constvals)
18241819
# We can't run check_jaxpr until after we normalize.
18251820
jaxpr, constvals = _drop_unused_vars(jaxpr, constvals)
18261821
jaxpr, out_type = _add_implicit_outputs(jaxpr)
@@ -1886,9 +1881,10 @@ def vars(atom: Atom) -> list[Var]:
18861881
used.update(v for atom in eqn.invars for v in vars(atom))
18871882
cvars, constvals = unzip2(
18881883
(v, val) for v, val in zip(jaxpr.constvars, constvals) if v in used)
1889-
jaxpr._constvars = list(cvars)
1890-
jaxpr._effects = make_jaxpr_effects(jaxpr.constvars, jaxpr.invars,
1891-
jaxpr.outvars, jaxpr.eqns)
1884+
jaxpr = jaxpr.replace(constvars=list(cvars),
1885+
effects=make_jaxpr_effects(jaxpr.constvars, jaxpr.invars,
1886+
jaxpr.outvars, jaxpr.eqns),
1887+
consts=list(constvals))
18921888
return jaxpr, list(constvals)
18931889

18941890

@@ -2831,5 +2827,5 @@ def convert_const_himutables(jaxpr):
28312827
effects = make_jaxpr_effects(constvars, invars, jaxpr.jaxpr.outvars,
28322828
jaxpr.jaxpr.eqns)
28332829
new_jaxpr = jaxpr.jaxpr.replace(constvars=constvars, invars=invars,
2834-
effects=effects)
2830+
effects=effects, consts=constvals)
28352831
return jaxpr.replace(jaxpr=new_jaxpr, consts=constvals), in_mutables

jax/_src/interpreters/pxla.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1734,6 +1734,8 @@ def _dce_jaxpr(closed_jaxpr, api_name, fun_name,
17341734
in_avals = closed_jaxpr.in_avals
17351735

17361736
if (keep_unused or auto_spmd_lowering or
1737+
# Don't drop inputs with symbolic shapes, because we may need them
1738+
# to infer the symbolic dimensions at call sites.
17371739
any(hasattr(a, "shape") and not core.is_constant_shape(a.shape)
17381740
for a in in_avals)):
17391741
kept_var_idx = set(range(len(in_avals)))

jax/_src/lax/control_flow/common.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -175,13 +175,18 @@ def _initial_style_jaxprs_with_common_consts(
175175
canonical_non_ref_indices.append(tuple(non_ref_indices))
176176

177177
consts = [*canonical_refs, *canonical_non_refs]
178-
jaxprs = tuple(_pad_jaxpr_constvars(jaxpr, i, (*canonical_ref_avals,), (*canonical_ref_indices,), (*canonical_non_ref_avals,), (*canonical_non_ref_indices,))
178+
jaxprs = tuple(_pad_jaxpr_constvars(jaxpr, i, consts,
179+
(*canonical_ref_avals,), (*canonical_ref_indices,), (*canonical_non_ref_avals,), (*canonical_non_ref_indices,))
179180
for i, jaxpr in enumerate(jaxprs))
180181
return jaxprs, consts, all_out_trees
181182

182-
@weakref_lru_cache
183-
def _pad_jaxpr_constvars(jaxpr, i, canonical_ref_avals, canonical_ref_indices,
184-
canonical_non_ref_avals, canonical_non_ref_indices):
183+
# DO_NOT_SUBMIT
184+
# @weakref_lru_cache
185+
def _pad_jaxpr_constvars(jaxpr: core.Jaxpr, i,
186+
consts: Sequence[core.Value],
187+
canonical_ref_avals, canonical_ref_indices,
188+
canonical_non_ref_avals, canonical_non_ref_indices
189+
) -> core.ClosedJaxpr:
185190
is_ref = [isinstance(v.aval, state.AbstractRef) for v in jaxpr.constvars]
186191
nonref_constvars, ref_constvars = partition_list(is_ref, jaxpr.constvars)
187192
padded_ref_constvars = map(core.Var, canonical_ref_avals)
@@ -191,7 +196,7 @@ def _pad_jaxpr_constvars(jaxpr, i, canonical_ref_avals, canonical_ref_indices,
191196
for canonical_id, non_ref_var in zip(canonical_non_ref_indices[i], nonref_constvars):
192197
padded_non_ref_constvars[canonical_id] = non_ref_var
193198
constvars = [*padded_ref_constvars, *padded_non_ref_constvars]
194-
jaxpr = jaxpr.replace(constvars=constvars)
199+
jaxpr = jaxpr.replace(constvars=constvars, consts=consts)
195200
effects = pe.make_jaxpr_effects(jaxpr.constvars, jaxpr.invars,
196201
jaxpr.outvars, jaxpr.eqns)
197202
jaxpr = jaxpr.replace(effects=effects)

jax/interpreters/partial_eval.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,6 @@
4848
config as config,
4949
const_fold_rules as const_fold_rules,
5050
convert_constvars_jaxpr as convert_constvars_jaxpr,
51-
convert_envvars_to_constvars as convert_envvars_to_constvars,
5251
convert_invars_to_constvars as convert_invars_to_constvars,
5352
custom_partial_eval_rules as custom_partial_eval_rules,
5453
custom_staging_rules as custom_staging_rules,

tests/core_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -402,13 +402,13 @@ def setUp(self):
402402
super().setUp()
403403
lax_control_flow._initial_style_open_jaxpr.cache_clear()
404404
lax_control_flow._initial_style_jaxpr.cache_clear()
405-
lax_control_flow.common._pad_jaxpr_constvars.cache_clear()
405+
# lax_control_flow.common._pad_jaxpr_constvars.cache_clear() # DO_NOT_SUBMIT
406406

407407
def tearDown(self):
408408
super().tearDown()
409409
lax_control_flow._initial_style_open_jaxpr.cache_clear()
410410
lax_control_flow._initial_style_jaxpr.cache_clear()
411-
lax_control_flow.common._pad_jaxpr_constvars.cache_clear()
411+
# lax_control_flow.common._pad_jaxpr_constvars.cache_clear() # DO_NOT_SUBMIT
412412

413413
def test_check_jaxpr_correct(self):
414414
jaxpr = make_jaxpr(lambda x: jnp.sin(x) + jnp.cos(x))(1.).jaxpr

tests/lax_control_flow_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@ def setUp(self):
177177
super().setUp()
178178
lax_control_flow._initial_style_open_jaxpr.cache_clear()
179179
lax_control_flow._initial_style_jaxpr.cache_clear()
180-
lax_control_flow.common._pad_jaxpr_constvars.cache_clear()
180+
# lax_control_flow.common._pad_jaxpr_constvars.cache_clear() # DO_NOT_SUBMIT
181181

182182
def testCallableErrors(self):
183183
not_callable = 42

0 commit comments

Comments
 (0)