@@ -51,6 +51,7 @@ def __init__(self):
51
51
self .meta = {}
52
52
self .FLAGS = NameSpace (self .read )
53
53
self .use_absl = False
54
+ self ._contextmanager_flags = set ()
54
55
55
56
# TODO(mattjj): delete these when only omnistaging is available
56
57
self .omnistaging_enabled = bool_env ('JAX_OMNISTAGING' , True )
@@ -71,6 +72,13 @@ def update(self, name, val):
71
72
lib .jax_jit .global_state ().enable_x64 = val
72
73
73
74
def read (self , name ):
75
+ if name in self ._contextmanager_flags :
76
+ raise AttributeError (
77
+ "For flags with a corresponding contextmanager, read their value "
78
+ f"via e.g. `config.{ name } ` rather than `config.FLAGS.{ name } `." )
79
+ return self ._read (name )
80
+
81
+ def _read (self , name ):
74
82
if self .use_absl :
75
83
return getattr (self .absl_flags .FLAGS , name )
76
84
else :
@@ -193,23 +201,17 @@ def define_bool_state(self, name: str, default: bool, help: str):
193
201
with enable_foo(True):
194
202
...
195
203
196
- Accessing ``config.FLAGS.jax_enable_foo`` is different from accessing the
197
- thread-local state value via ``config.jax_enable_foo``: the former reads the
198
- flag value determined set by the environment variable or command-line flag
199
- and does not read the thread-local state, whereas the latter reads the
200
- thread-local state value managed by the contextmanager. Think of the
201
- contextmanager state as a layer on top of the flag value: if no
202
- contextmanager is in use then ``config.jax_enable_foo`` reflects the flag
203
- value ``config.FLAGS.jax_enable_foo``, whereas if a contextmanager is in use
204
- then only ``config.jax_enable_foo`` is updated. So in general using
205
- ``config.jax_enable_foo`` is best.
204
+ The value of the thread-local state or flag can be accessed via
205
+ ``config.jax_enable_foo``. Reading it via ``config.FLAGS.jax_enable_foo`` is
206
+ an error.
206
207
"""
207
208
name = name .lower ()
208
209
self .DEFINE_bool (name , bool_env (name .upper (), default ), help )
210
+ self ._contextmanager_flags .add (name )
209
211
210
212
def get_state (self ):
211
213
val = getattr (_thread_local_state , name , unset )
212
- return val if val is not unset else self .read (name )
214
+ return val if val is not unset else self ._read (name )
213
215
setattr (Config , name , property (get_state ))
214
216
215
217
@contextlib .contextmanager
0 commit comments