
Commit c2ef1fc

add dot precision flag, enum state utility
1 parent f76df86 commit c2ef1fc

7 files changed (+213, -48 lines)


jax/__init__.py

Lines changed: 2 additions & 1 deletion

@@ -31,7 +31,8 @@

 # flake8: noqa: F401
 from .config import (config, enable_checks, check_tracer_leaks, checking_leaks,
-                     debug_nans, debug_infs, log_compiles)
+                     debug_nans, debug_infs, log_compiles,
+                     default_dot_precision, numpy_rank_promotion)
 from .api import (
   ad,  # TODO(phawkins): update users to avoid this.
   argnums_partial,  # TODO(phawkins): update Haiku to not use this.
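
The two names added to the .config import here, default_dot_precision and numpy_rank_promotion, are the context managers defined in jax/config.py further down in this commit. A minimal usage sketch of what the new top-level exports allow, assuming only what this commit adds:

import jax
import jax.numpy as jnp

x = jnp.zeros((2, 2))

# Thread-local default precision for dot/conv on 32-bit inputs;
# 'tensorfloat32' resolves to Precision.HIGH (see _precision_strings below).
with jax.default_dot_precision("tensorfloat32"):
  jnp.dot(x, x)

# Rank-promotion policy for NumPy-style broadcasting; under 'raise' the
# implicit rank promotion below becomes an error instead of a silent broadcast.
try:
  with jax.numpy_rank_promotion("raise"):
    jnp.ones(3) + jnp.ones((2, 3))
except ValueError as e:
  print(e)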

jax/_src/lax/lax.py

Lines changed: 22 additions & 12 deletions

@@ -496,7 +496,13 @@ def concatenate(operands: Sequence[Array], dimension: int) -> Array:
 Precision = xla_client.PrecisionConfig.Precision
 Precision.__str__ = lambda precision: precision.name
 PrecisionType = Any
-PrecisionLike = Union[None, PrecisionType, Tuple[PrecisionType, PrecisionType]]
+PrecisionLike = Union[None, str, PrecisionType,
+                      Tuple[PrecisionType, PrecisionType]]
+_precision_strings = {
+    'bfloat16': Precision.DEFAULT,
+    'tensorfloat32': Precision.HIGH,
+    'float32': Precision.HIGHEST,
+}


 class ConvDimensionNumbers(NamedTuple):

@@ -551,9 +557,10 @@ def conv_general_dilated(
     feature_group_count: integer, default 1. See XLA HLO docs.
     batch_group_count: integer, default 1. See XLA HLO docs.
     precision: Optional. Either ``None``, which means the default precision for
-      the backend, a ``lax.Precision`` enum value (``Precision.DEFAULT``,
-      ``Precision.HIGH`` or ``Precision.HIGHEST``) or a tuple of two
-      ``lax.Precision`` enums indicating precision of ``lhs``` and ``rhs``.
+      the backend, a string ('bfloat16', 'tensorfloat32', or 'float32'), a
+      ``lax.Precision`` enum value (``Precision.DEFAULT``, ``Precision.HIGH`` or
+      ``Precision.HIGHEST``) or a tuple of two ``lax.Precision`` enums
+      indicating precision of ``lhs``` and ``rhs``.

   Returns:
     An array containing the convolution result.

@@ -6378,16 +6385,19 @@ def remaining(original, *removed_lists):

 def _canonicalize_precision(precision):
   if precision is None:
-    return None
-  if isinstance(precision, Precision) or (
-      isinstance(precision, tuple)
-      and len(precision) == 2
-      and all(isinstance(p, Precision) for p in precision)
-  ):
+    return _precision_strings.get(config.jax_default_dot_precision)
+  elif isinstance(precision, str) and precision in _precision_strings:
+    return _precision_strings.get(precision)
+  elif isinstance(precision, Precision):
+    return precision
+  elif (isinstance(precision, (list, tuple)) and len(precision) == 2 and
+        all(isinstance(p, Precision) for p in precision)):
     return precision
   else:
-    raise ValueError("Precision argument must be None, a lax.Precision value "
-                     f"or a tuple of two lax.Precision values; got {precision}")
+    raise ValueError(
+        f"Precision argument must be None, a string in {_precision_strings}, "
+        "a lax.Precision value or a tuple of two lax.Precision values; "
+        f"got {precision}.")


 def conv_dimension_numbers(lhs_shape, rhs_shape, dimension_numbers
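
With the changes above, precision arguments to lax functions may also be spelled as strings, and precision=None now defers to the jax_default_dot_precision option. A short sketch of the accepted spellings, assuming the _precision_strings mapping introduced here:

from jax import lax
import jax.numpy as jnp

x = jnp.zeros((2, 2))

# String spellings map through _precision_strings:
#   'bfloat16'      -> Precision.DEFAULT
#   'tensorfloat32' -> Precision.HIGH
#   'float32'       -> Precision.HIGHEST
lax.dot(x, x, precision='float32')
lax.dot(x, x, precision=lax.Precision.HIGHEST)                        # same thing
lax.dot(x, x, precision=(lax.Precision.HIGH, lax.Precision.HIGHEST))  # per-operand

# With precision=None (the default), _canonicalize_precision now consults
# config.jax_default_dot_precision rather than always returning None.
lax.dot(x, x)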

jax/_src/numpy/lax_numpy.py

Lines changed: 3 additions & 11 deletions

@@ -54,14 +54,6 @@
     canonicalize_axis as _canonicalize_axis, maybe_named_axis)
 from jax.tree_util import tree_leaves, tree_flatten, tree_map

-FLAGS = flags.FLAGS
-flags.DEFINE_enum(
-    'jax_numpy_rank_promotion', os.getenv('JAX_NUMPY_RANK_PROMOTION', 'allow'),
-    enum_values=['allow', 'warn', 'raise'],
-    help=
-    'Control NumPy-style automatic rank promotion broadcasting '
-    '("allow", "warn", or "raise").')
-
 newaxis = None

 # Common docstring additions:

@@ -246,20 +238,20 @@ def _promote_shapes(fun_name, *args):
     if not nonscalar_ranks or len(set(nonscalar_ranks)) == 1:
       return args
     else:
-      if FLAGS.jax_numpy_rank_promotion != "allow":
+      if config.jax_numpy_rank_promotion != "allow":
         _rank_promotion_warning_or_error(fun_name, shapes)
       result_rank = len(lax.broadcast_shapes(*shapes))
       return [broadcast_to(arg, (1,) * (result_rank - len(shp)) + shp)
               for arg, shp in zip(args, shapes)]

 def _rank_promotion_warning_or_error(fun_name, shapes):
-  if FLAGS.jax_numpy_rank_promotion == "warn":
+  if config.jax_numpy_rank_promotion == "warn":
     msg = ("Following NumPy automatic rank promotion for {} on shapes {}. "
            "Set the jax_numpy_rank_promotion config option to 'allow' to "
            "disable this warning; for more information, see "
            "https://jax.readthedocs.io/en/latest/rank_promotion_warning.html.")
     warnings.warn(msg.format(fun_name, ' '.join(map(str, shapes))))
-  elif FLAGS.jax_numpy_rank_promotion == "raise":
+  elif config.jax_numpy_rank_promotion == "raise":
     msg = ("Operands could not be broadcast together for {} on shapes {} "
            "and with the config option jax_numpy_rank_promotion='raise'. "
            "For more information, see "

jax/api.py

Lines changed: 18 additions & 14 deletions

@@ -356,7 +356,9 @@ def f_jitted(*args, **kwargs):
    # TODO(jblespiau): We can remove `config.x64_enabled` when jaxlib 0.1.62 is
    # the minimal version. NOTE(mattjj): minversion 0.1.62 didn't work...
    context = (getattr(core.thread_local_state.trace_state.trace_stack,
-                       "dynamic", None), config.x64_enabled)
+                       "dynamic", None),
+               config.x64_enabled,
+               config.jax_default_dot_precision)
    # TODO(jblespiau): Move this to C++.
    if (config.jax_debug_nans or config.jax_debug_infs) and not _jit_is_disabled():
      device_arrays = cpp_jitted_f(context, *args, **kwargs)

@@ -2442,6 +2444,21 @@ def named_f(*args, **kwargs):

   return named_f

+
+def invertible(fun: Callable) -> Callable:
+  """Asserts that the decorated function is invertible.
+
+  Applying reverse-mode AD to a decorated function will use a more memory efficient
+  procedure than usual, which will reconstruct the necessary intermediate values
+  by inverting the function. Note that this might degrade the numerical accuracy of
+  obtained gradients if the inverse is unstable.
+
+  Args:
+    fun: The function assumed to be invertible.
+  """
+  return iad.invertible(fun)
+
+
 # TODO(mattjj): delete everything below here (deprecated custom_transforms)

 class CustomTransformsFunction(object):

@@ -2583,16 +2600,3 @@ def vjpfun(ct):
             for x, vjp in zip(primals, vjprules))
     return ans, vjpfun
   defvjp_all(fun, custom_vjp)
-
-def invertible(fun: Callable) -> Callable:
-  """Asserts that the decorated function is invertible.
-
-  Applying reverse-mode AD to a decorated function will use a more memory efficient
-  procedure than usual, which will reconstruct the necessary intermediate values
-  by inverting the function. Note that this might degrade the numerical accuracy of
-  obtained gradients if the inverse is unstable.
-
-  Args:
-    fun: The function assumed to be invertible.
-  """
-  return iad.invertible(fun)
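
Adding config.jax_default_dot_precision to the context tuple makes the default precision part of the key the C++ jit uses to look up compiled computations, so a function traced under one default is not silently reused under another. A rough sketch of the behavior this is meant to guarantee (the caching itself happens inside cpp_jitted_f):

import jax
import jax.numpy as jnp

@jax.jit
def matmul(a, b):
  return jnp.dot(a, b)

x = jnp.zeros((2, 2))

with jax.default_dot_precision("float32"):
  matmul(x, x)   # traced and compiled with precision=HIGHEST

with jax.default_dot_precision("bfloat16"):
  matmul(x, x)   # different context -> separate cache entry, precision=DEFAULT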

jax/config.py

Lines changed: 92 additions & 7 deletions

@@ -1,6 +1,6 @@
-# Copyright 2018 Google LLC
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
+# Copyright 2018 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #

@@ -17,6 +17,7 @@
 import os
 import sys
 import threading
+from typing import List, Optional

 from jax import lib

@@ -51,7 +52,7 @@ class Config:
   def __init__(self):
     self.values = {}
     self.meta = {}
-    self.FLAGS = NameSpace(self.read)
+    self.FLAGS = NameSpace(self.read, self.update)
     self.use_absl = False
     self._contextmanager_flags = set()

@@ -219,6 +220,54 @@ def get_state(self):
     @contextlib.contextmanager
     def set_state(new_val: bool):
       prev_val = getattr(_thread_local_state, name, unset)
+      setattr(_thread_local_state, name, bool(new_val))
+      try:
+        yield
+      finally:
+        if prev_val is unset:
+          delattr(_thread_local_state, name)
+        else:
+          setattr(_thread_local_state, name, prev_val)
+    set_state.__name__ = name[4:] if name.startswith('jax_') else name
+    set_state.__doc__ = f"Context manager for `{name}` config option.\n\n{help}"
+    return set_state
+
+  def define_enum_state(self, name: str, enum_values: List[str],
+                        default: Optional[str], help: str):
+    """Set up thread-local state and return a contextmanager for managing it.
+
+    Args:
+      name: string, converted to lowercase to define the name of the config
+        option (and absl flag). It is converted to uppercase to define the
+        corresponding shell environment variable.
+      enum_values: list of strings representing the possible values for the
+        option.
+      default: optional string, default value.
+      help: string, used to populate the flag help information as well as the
+        docstring of the returned context manager.
+
+    Returns:
+      A contextmanager to control the thread-local state value.
+
+    See docstring for ``define_bool_state``.
+    """
+    name = name.lower()
+    self.DEFINE_enum(name, os.getenv(name.upper(), default),
+                     enum_values=enum_values, help=help)
+    self._contextmanager_flags.add(name)
+
+    def get_state(self):
+      val = getattr(_thread_local_state, name, unset)
+      return val if val is not unset else self._read(name)
+    setattr(Config, name, property(get_state))
+
+    @contextlib.contextmanager
+    def set_state(new_val: Optional[str]):
+      if (new_val is not None and
+          (type(new_val) is not str or new_val not in enum_values)):
+        raise ValueError(f"new enum value must be None or in {enum_values}, "
+                         f"got {new_val} of type {type(new_val)}.")
+      prev_val = getattr(_thread_local_state, name, unset)
       setattr(_thread_local_state, name, new_val)
       try:
         yield

@@ -231,18 +280,25 @@ def set_state(new_val: bool):
     set_state.__doc__ = f"Context manager for `{name}` config option.\n\n{help}"
     return set_state

+
 _thread_local_state = threading.local()

 class Unset: pass
 unset = Unset()

-class NameSpace(object):
-  def __init__(self, getter):
-    self._getter = getter
+class NameSpace:
+  def __init__(self, getter, setter):
+    # must use super because we override this class's __setattr__, see
+    # https://docs.python.org/3/reference/datamodel.html#object.__setattr__
+    super().__setattr__('_getter', getter)
+    super().__setattr__('_setter', setter)

   def __getattr__(self, name):
     return self._getter(name)

+  def __setattr__(self, name, val):
+    self._setter(name, val)
+

 config = Config()
 flags = config

@@ -316,3 +372,32 @@ def __getattr__(self, name):
     'computation. Logging is performed with `absl.logging`. When this '
     'option is set, the log level is WARNING; otherwise the level is '
     'DEBUG.'))
+
+
+numpy_rank_promotion = config.define_enum_state(
+    name='jax_numpy_rank_promotion',
+    enum_values=['allow', 'warn', 'raise'],
+    default='allow',
+    help=('Control NumPy-style automatic rank promotion broadcasting '
+          '("allow", "warn", or "raise").'))
+
+default_dot_precision = config.define_enum_state(
+    name='jax_default_dot_precision',
+    enum_values=['bfloat16', 'tensorfloat32', 'float32'],
+    default=None,
+    help=('Control the default matmul and conv precision for 32bit inputs.\n\n'
+
+          'Some platforms, like TPU, offer configurable precision levels for '
+          'matrix multiplication and convolution computations, trading off '
+          'accuracy for speed. The precision can be controlled for each '
+          'operation; for example, see the :func:`jax.lax.conv_general_dilated` '
+          'and :func:`jax.lax.dot` docstrings. But it can be useful to control '
+          'the default behavior obtained when an operation is not given a '
+          'specific precision.\n\n'
+
+          'This option can be used to control the default precision '
+          'level for computations involved in matrix multiplication and '
+          'convolution on 32bit inputs. The levels roughly describe the '
+          "precision at which scalar products are computed. The 'bfloat16' "
+          "option is the fastest and least precise; 'float32' is similar to "
+          "full 'float32 precision; 'tensorfloat32' is intermediate.\n\n"))

tests/api_test.py

Lines changed: 52 additions & 0 deletions

@@ -2393,6 +2393,58 @@ def f(_):
     expected = jnp.arange(1) + 1
     self.assertAllClose(ans, expected)

+  def test_dot_precision_context_manager(self):
+    x = jnp.zeros((2, 2))
+
+    with jax.default_dot_precision(None):
+      jnp.dot(x, x)  # doesn't crash
+      jaxpr = jax.make_jaxpr(jnp.dot)(x, x)
+    self.assertIn('precision=None', str(jaxpr))
+
+    with jax.default_dot_precision("bfloat16"):
+      jnp.dot(x, x)  # doesn't crash
+      jaxpr = jax.make_jaxpr(jnp.dot)(x, x)
+    self.assertIn('precision=DEFAULT', str(jaxpr))
+
+    with jax.default_dot_precision("tensorfloat32"):
+      jnp.dot(x, x)  # doesn't crash
+      jaxpr = jax.make_jaxpr(jnp.dot)(x, x)
+    self.assertIn('precision=HIGH', str(jaxpr))
+
+    with jax.default_dot_precision("float32"):
+      jnp.dot(x, x)  # doesn't crash
+      jaxpr = jax.make_jaxpr(jnp.dot)(x, x)
+    self.assertIn('precision=HIGHEST', str(jaxpr))
+
+    dot = partial(jnp.dot, precision=lax.Precision.HIGHEST)
+    with jax.default_dot_precision("tensorfloat32"):
+      dot(x, x)  # doesn't crash
+      jaxpr = jax.make_jaxpr(dot)(x, x)
+    self.assertIn('precision=HIGHEST', str(jaxpr))
+
+  def test_dot_precision_flag(self):
+    x = jnp.zeros((2, 2))
+
+    prev_val = config._read("jax_default_dot_precision")
+    try:
+      config.FLAGS.jax_default_dot_precision = "tensorfloat32"
+      jnp.dot(x, x)  # doesn't crash
+      jaxpr = jax.make_jaxpr(jnp.dot)(x, x)
+    finally:
+      config.FLAGS.jax_default_dot_precision = prev_val
+    self.assertIn('precision=HIGH', str(jaxpr))
+    self.assertEqual(prev_val, config._read("jax_default_dot_precision"))
+
+    prev_val = config._read("jax_default_dot_precision")
+    try:
+      config.update('jax_default_dot_precision','tensorfloat32')
+      jnp.dot(x, x)  # doesn't crash
+      jaxpr = jax.make_jaxpr(jnp.dot)(x, x)
+    finally:
+      config.update('jax_default_dot_precision', prev_val)
+    self.assertIn('precision=HIGH', str(jaxpr))
+    self.assertEqual(prev_val, config._read("jax_default_dot_precision"))
+

 class RematTest(jtu.JaxTestCase):
