17
17
from absl .testing import absltest
18
18
from absl .testing import parameterized
19
19
from functools import partial
20
+ import itertools as it
20
21
import numpy as np
21
22
import jax
22
23
from jax ._src import core
@@ -367,14 +368,16 @@ def loss(x, y):
367
368
368
369
@parameterized .parameters ([False , True ])
369
370
def test_custom_vjp_grad_stats_plumbing_basic (self , jit ):
370
- @jax .jit
371
371
def primal (grads_ref , x ): # note: jit-abstracted!
372
372
x = jnp .sin (x )
373
373
x = stash_grads (grads_ref , x )
374
374
x = jnp .sin (x )
375
375
x = stash_grads (grads_ref , x ) # ignored, order-preserved
376
376
return x
377
377
378
+ if jit :
379
+ primal = jax .jit (primal )
380
+
378
381
@jax .custom_vjp
379
382
def stash_grads (grads_ref , x ):
380
383
return x
@@ -389,18 +392,22 @@ def stash_grads_bwd(grads_ref, g):
389
392
jax .grad (primal , 1 )(grads_ref , jnp .float32 (1.0 ))
390
393
self .assertAllClose (grads_ref [...], jnp .cos (jnp .sin (1. )), check_dtypes = False )
391
394
392
- @parameterized .parameters ([False , True ])
393
- def test_custom_vjp_grad_stats_plumbing_scan (self , jit ):
394
- @jax .jit
395
+ @parameterized .parameters (it .product ([False , True ], repeat = 2 ))
396
+ def test_custom_vjp_grad_stats_plumbing_scan (self , jit , remat ):
395
397
def primal (grads_ref , x ): # note: jit-abstracted!
396
398
def body (x , _ ):
397
399
x = jnp .sin (x )
398
400
x = stash_grads (grads_ref , x )
399
401
x = jnp .sin (x )
400
402
return x , ()
403
+ if remat :
404
+ body = jax .remat (body )
401
405
x , () = jax .lax .scan (body , x , None , length = 1 )
402
406
return x
403
407
408
+ if jit :
409
+ primal = jax .jit (primal )
410
+
404
411
@jax .custom_vjp
405
412
def stash_grads (grads_ref , x ):
406
413
return x
@@ -415,6 +422,35 @@ def stash_grads_bwd(grads_ref, g):
415
422
jax .grad (primal , argnums = 1 )(grads_ref , jnp .float32 (1.0 ))
416
423
self .assertAllClose (grads_ref [...], jnp .cos (jnp .sin (1. )), check_dtypes = False )
417
424
425
+ # TODO(mattjj,dougalm): this errors, which may or may not be desired behavior
426
+ # def test_remat_basic(self):
427
+ # @jax.remat
428
+ # def f(x_ref, y):
429
+ # x_ref[...] += 1
430
+ # return y
431
+
432
+ # x_ref = core.mutable_array(0)
433
+ # jax.grad(f, 1)(x_ref, 3.14)
434
+
435
+ def test_remat_grad_stats_plumbing_basic (self ):
436
+ @jax .remat
437
+ def f (x_ref , y ):
438
+ stash_grads (x_ref , y )
439
+ return y
440
+
441
+ @jax .custom_vjp
442
+ def stash_grads (grads_ref , x ):
443
+ return x
444
+ def stash_grads_fwd (grads_ref , x ):
445
+ return x , grads_ref
446
+ def stash_grads_bwd (grads_ref , g ):
447
+ grads_ref [...] = g
448
+ return None , g
449
+ stash_grads .defvjp (stash_grads_fwd , stash_grads_bwd )
450
+
451
+ x_ref = core .mutable_array (0 )
452
+ jax .grad (f , 1 )(x_ref , 3.14 )
453
+
418
454
419
455
@jtu .with_config (jax_mutable_array_checks = True )
420
456
class MutableArrayErrorsTest (jtu .JaxTestCase ):
0 commit comments