Skip to content

Commit ab848e7

Browse files
committed
[mutable-arrays] remat discharge rule
1 parent 56f3293 commit ab848e7

File tree

3 files changed

+60
-4
lines changed

3 files changed

+60
-4
lines changed

jax/_src/core.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3080,6 +3080,11 @@ def write(v: Var, a: AvalQDD) -> None:
30803080
f"from source: {src}"])
30813081
raise JaxprTypeError(msg, eqn_idx) from None
30823082

3083+
# Check that no refs are returned. TODO(mattjj): improve this error message
3084+
from jax._src.state.types import AbstractRef # pytype: disable=import-error
3085+
for v in jaxpr.outvars:
3086+
if isinstance(v.aval, AbstractRef): raise TypeError("returned ref")
3087+
30833088
# TODO(mattjj): include output type annotation on jaxpr and check it here
30843089
foreach(read, jaxpr.outvars)
30853090

jax/_src/state/discharge.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from typing import Any, Protocol, TypeVar
2323

2424
from jax._src import ad_util
25+
from jax._src import ad_checkpoint
2526
from jax._src import api_util
2627
from jax._src import core
2728
from jax._src import linear_util as lu
@@ -1179,3 +1180,17 @@ def _pjit_state_discharge_rule(
11791180
sentinel = object()
11801181
assert next(ref_vals_iter, sentinel) is sentinel
11811182
return new_invals, out_vals
1183+
1184+
1185+
@register_discharge_rule(ad_checkpoint.remat_p)
def _remat_state_discharge_rule(
    in_avals, out_avals, *args, jaxpr, **params):
  """State-discharge rule for ``ad_checkpoint.remat_p``.

  Discharges the inner ``jaxpr`` with ``discharge_state`` and re-binds
  ``remat_p`` on the discharged jaxpr, forwarding ``args`` and all other
  ``params`` unchanged.

  Args:
    in_avals: avals of the equation inputs; entries that are ``AbstractRef``
      mark the inputs that receive an updated value.
    out_avals: avals of the equation outputs (not used by this rule).
    *args: input values passed through to the re-bound ``remat_p``.
    jaxpr: the jaxpr carried by the ``remat_p`` equation.
    **params: remaining ``remat_p`` parameters, passed through as-is.

  Returns:
    A pair ``(new_invals, out_vals)``. ``new_invals`` has one entry per
    element of ``in_avals``: the next trailing output of the discharged call
    for ``AbstractRef`` inputs, and ``None`` for every other input.
    ``out_vals`` holds the leading ``len(jaxpr.outvars)`` outputs.
  """
  # The `()` pattern requires the second result of discharge_state to be
  # empty (presumably the discharged consts, since none are passed in).
  discharged_jaxpr, () = discharge_state(jaxpr, [])
  out_vals_ref_vals = ad_checkpoint.remat_p.bind(
      *args, jaxpr=discharged_jaxpr, **params)
  # The first len(jaxpr.outvars) results are the original outputs; whatever
  # follows is handed out, in order, to the ref-typed inputs below.
  out_vals, ref_vals = split_list(out_vals_ref_vals, [len(jaxpr.outvars)])
  ref_vals_iter = iter(ref_vals)
  new_invals = [next(ref_vals_iter) if isinstance(a, AbstractRef) else None
                for a in in_avals]
  # Use a unique sentinel (as the pjit discharge rule does) so the
  # exhaustion check cannot be confused with a legitimate value.
  sentinel = object()
  assert next(ref_vals_iter, sentinel) is sentinel
  return new_invals, out_vals

tests/mutable_array_test.py

Lines changed: 40 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from absl.testing import absltest
1818
from absl.testing import parameterized
1919
from functools import partial
20+
import itertools as it
2021
import numpy as np
2122
import jax
2223
from jax._src import core
@@ -367,14 +368,16 @@ def loss(x, y):
367368

368369
@parameterized.parameters([False, True])
369370
def test_custom_vjp_grad_stats_plumbing_basic(self, jit):
370-
@jax.jit
371371
def primal(grads_ref, x): # note: jit-abstracted!
372372
x = jnp.sin(x)
373373
x = stash_grads(grads_ref, x)
374374
x = jnp.sin(x)
375375
x = stash_grads(grads_ref, x) # ignored, order-preserved
376376
return x
377377

378+
if jit:
379+
primal = jax.jit(primal)
380+
378381
@jax.custom_vjp
379382
def stash_grads(grads_ref, x):
380383
return x
@@ -389,18 +392,22 @@ def stash_grads_bwd(grads_ref, g):
389392
jax.grad(primal, 1)(grads_ref, jnp.float32(1.0))
390393
self.assertAllClose(grads_ref[...], jnp.cos(jnp.sin(1.)), check_dtypes=False)
391394

392-
@parameterized.parameters([False, True])
393-
def test_custom_vjp_grad_stats_plumbing_scan(self, jit):
394-
@jax.jit
395+
@parameterized.parameters(it.product([False, True], repeat=2))
396+
def test_custom_vjp_grad_stats_plumbing_scan(self, jit, remat):
395397
def primal(grads_ref, x): # note: jit-abstracted!
396398
def body(x, _):
397399
x = jnp.sin(x)
398400
x = stash_grads(grads_ref, x)
399401
x = jnp.sin(x)
400402
return x, ()
403+
if remat:
404+
body = jax.remat(body)
401405
x, () = jax.lax.scan(body, x, None, length=1)
402406
return x
403407

408+
if jit:
409+
primal = jax.jit(primal)
410+
404411
@jax.custom_vjp
405412
def stash_grads(grads_ref, x):
406413
return x
@@ -415,6 +422,35 @@ def stash_grads_bwd(grads_ref, g):
415422
jax.grad(primal, argnums=1)(grads_ref, jnp.float32(1.0))
416423
self.assertAllClose(grads_ref[...], jnp.cos(jnp.sin(1.)), check_dtypes=False)
417424

425+
# TODO(mattjj,dougalm): this errors, which may or may not be desired behavior
426+
# def test_remat_basic(self):
427+
# @jax.remat
428+
# def f(x_ref, y):
429+
# x_ref[...] += 1
430+
# return y
431+
432+
# x_ref = core.mutable_array(0)
433+
# jax.grad(f, 1)(x_ref, 3.14)
434+
435+
def test_remat_grad_stats_plumbing_basic(self):
  """Smoke test: grad through a ``jax.remat`` function that stashes into a ref.

  ``stash_grads`` is an identity ``custom_vjp`` whose backward pass writes the
  incoming cotangent into the ref it was given. The test only checks that
  differentiating the rematerialized function runs; it makes no assertion
  about the value left in the ref.
  """
  @jax.custom_vjp
  def stash_grads(grads_ref, x):
    return x

  def fwd(grads_ref, x):
    # The ref itself is the residual handed to the backward pass.
    return x, grads_ref

  def bwd(grads_ref, g):
    grads_ref[...] = g
    return None, g

  stash_grads.defvjp(fwd, bwd)

  @jax.remat
  def f(x_ref, y):
    # The result of stash_grads is intentionally discarded; y is returned.
    stash_grads(x_ref, y)
    return y

  x_ref = core.mutable_array(0)
  jax.grad(f, argnums=1)(x_ref, 3.14)
453+
418454

419455
@jtu.with_config(jax_mutable_array_checks=True)
420456
class MutableArrayErrorsTest(jtu.JaxTestCase):

0 commit comments

Comments
 (0)