[mutable-arrays] re-land #29353 #29421
Merged
+237
−108
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Re-land #29353 after rollback
The code in #29353 had a bug: on this line we forgot to offset by
len(jaxpr.consts)
, since theint
entries ofin_fwd_res: list[int | None]
index into[*jaxpr.consts, *known_inputs]
whereknown_inputs = [t.pval.get_known() for t in tracers if t.pval.is_known()]
.In #29353 I made the classic hasty blunder of not writing a systematic test for the new code. This updated PR has a very thorough systematic test. I verified it catches the bug in #29353 with
JAX_NUM_GENERATED_CASES=100
. This new PR passes all tests withJAX_NUM_GENERATED_CASES=9999
, which is the damage limit.