Skip to content

Commit 7f3bb63

Browse files
committed
make config.FLAGS.jax_enable_foo an error
1 parent 8acad26 commit 7f3bb63

File tree

3 files changed

+16
-14
lines changed

3 files changed

+16
-14
lines changed

jax/config.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ def __init__(self):
5151
self.meta = {}
5252
self.FLAGS = NameSpace(self.read)
5353
self.use_absl = False
54+
self._contextmanager_flags = set()
5455

5556
# TODO(mattjj): delete these when only omnistaging is available
5657
self.omnistaging_enabled = bool_env('JAX_OMNISTAGING', True)
@@ -71,6 +72,13 @@ def update(self, name, val):
7172
lib.jax_jit.global_state().enable_x64 = val
7273

7374
def read(self, name):
75+
if name in self._contextmanager_flags:
76+
raise AttributeError(
77+
"For flags with a corresponding contextmanager, read their value "
78+
f"via e.g. `config.{name}` rather than `config.FLAGS.{name}`.")
79+
return self._read(name)
80+
81+
def _read(self, name):
7482
if self.use_absl:
7583
return getattr(self.absl_flags.FLAGS, name)
7684
else:
@@ -193,23 +201,17 @@ def define_bool_state(self, name: str, default: bool, help: str):
193201
with enable_foo(True):
194202
...
195203
196-
Accessing ``config.FLAGS.jax_enable_foo`` is different from accessing the
197-
thread-local state value via ``config.jax_enable_foo``: the former reads the
198-
flag value determined set by the environment variable or command-line flag
199-
and does not read the thread-local state, whereas the latter reads the
200-
thread-local state value managed by the contextmanager. Think of the
201-
contextmanager state as a layer on top of the flag value: if no
202-
contextmanager is in use then ``config.jax_enable_foo`` reflects the flag
203-
value ``config.FLAGS.jax_enable_foo``, whereas if a contextmanager is in use
204-
then only ``config.jax_enable_foo`` is updated. So in general using
205-
``config.jax_enable_foo`` is best.
204+
The value of the thread-local state or flag can be accessed via
205+
``config.jax_enable_foo``. Reading it via ``config.FLAGS.jax_enable_foo`` is
206+
an error.
206207
"""
207208
name = name.lower()
208209
self.DEFINE_bool(name, bool_env(name.upper(), default), help)
210+
self._contextmanager_flags.add(name)
209211

210212
def get_state(self):
211213
val = getattr(_thread_local_state, name, unset)
212-
return val if val is not unset else self.read(name)
214+
return val if val is not unset else self._read(name)
213215
setattr(Config, name, property(get_state))
214216

215217
@contextlib.contextmanager

tests/api_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2334,7 +2334,7 @@ def test_leak_checker_catches_a_sublevel_leak(self):
23342334
if not config.omnistaging_enabled:
23352335
raise unittest.SkipTest("test only works with omnistaging")
23362336

2337-
with core.checking_leaks():
2337+
with jax.checking_leaks():
23382338
@jit
23392339
def f(x):
23402340
lst = []

tests/debug_nans_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
class DebugNaNsTest(jtu.JaxTestCase):
3131

3232
def setUp(self):
33-
self.cfg = config.read("jax_debug_nans")
33+
self.cfg = config._read("jax_debug_nans")
3434
config.update("jax_debug_nans", True)
3535

3636
def tearDown(self):
@@ -144,7 +144,7 @@ def testPjit(self):
144144
class DebugInfsTest(jtu.JaxTestCase):
145145

146146
def setUp(self):
147-
self.cfg = config.read("jax_debug_infs")
147+
self.cfg = config._read("jax_debug_infs")
148148
config.update("jax_debug_infs", True)
149149

150150
def tearDown(self):

0 commit comments

Comments
 (0)