
Commit 89768a3

add jax_default_matmul_precision flag & context mngr

committed · 1 parent 7b4c2e3 · commit 89768a3

File tree: 6 files changed, +224 -51 lines changed


jax/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -31,7 +31,8 @@
 
 # flake8: noqa: F401
 from .config import (config, enable_checks, check_tracer_leaks, checking_leaks,
-                     debug_nans, debug_infs, log_compiles)
+                     debug_nans, debug_infs, log_compiles,
+                     default_matmul_precision, numpy_rank_promotion)
 from .api import (
   ad,  # TODO(phawkins): update users to avoid this.
   argnums_partial,  # TODO(phawkins): update Haiku to not use this.
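
For orientation, a minimal sketch (not part of the commit) of how the two context managers exported here could be used once the change is in place; the shapes are illustrative only:

    import jax
    import jax.numpy as jnp

    x = jnp.zeros((2, 2))

    # Scope the default matmul/conv precision for 32-bit inputs.
    with jax.default_matmul_precision("float32"):
      y = jnp.dot(x, x)

    # Scope the NumPy-style rank-promotion behavior.
    with jax.numpy_rank_promotion("raise"):
      z = x + x  # same-rank operands, so no implicit rank promotion occurs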

jax/_src/lax/lax.py

Lines changed: 55 additions & 29 deletions
@@ -506,8 +506,17 @@ def concatenate(operands: Sequence[Array], dimension: int) -> Array:
 Precision = xla_client.PrecisionConfig.Precision
 Precision.__str__ = lambda precision: precision.name
 PrecisionType = Any
-PrecisionLike = Union[None, PrecisionType, Tuple[PrecisionType, PrecisionType]]
-
+PrecisionLike = Union[None, str, PrecisionType, Tuple[str, str],
+                      Tuple[PrecisionType, PrecisionType]]
+_precision_strings = {
+    'highest': Precision.HIGHEST,
+    'float32': Precision.HIGHEST,
+    'bfloat16_3x': Precision.HIGH,
+    'tensorfloat32': Precision.HIGH,
+    'bfloat16': Precision.DEFAULT,
+    'fastest': Precision.DEFAULT,
+    None: Precision.DEFAULT,
+}
 
 class ConvDimensionNumbers(NamedTuple):
   """Describes batch, spatial, and feature dimensions of a convolution.

@@ -555,42 +564,44 @@ def conv_general_dilated(
     rhs_dilation: `None`, or a sequence of `n` integers, giving the
       dilation factor to apply in each spatial dimension of `rhs`. RHS dilation
       is also known as atrous convolution.
-    dimension_numbers: either `None`, a `ConvDimensionNumbers` object, or
-      a 3-tuple `(lhs_spec, rhs_spec, out_spec)`, where each element is a string
-      of length `n+2`.
+    dimension_numbers: either `None`, a ``ConvDimensionNumbers`` object, or
+      a 3-tuple ``(lhs_spec, rhs_spec, out_spec)``, where each element is a
+      string of length `n+2`.
     feature_group_count: integer, default 1. See XLA HLO docs.
     batch_group_count: integer, default 1. See XLA HLO docs.
     precision: Optional. Either ``None``, which means the default precision for
       the backend, a ``lax.Precision`` enum value (``Precision.DEFAULT``,
-      ``Precision.HIGH`` or ``Precision.HIGHEST``) or a tuple of two
-      ``lax.Precision`` enums indicating precision of ``lhs``` and ``rhs``.
+      ``Precision.HIGH`` or ``Precision.HIGHEST``), a string (e.g. 'highest' or
+      'fastest', see the ``jax.default_matmul_precision`` context manager), or a
+      tuple of two ``lax.Precision`` enums or strings indicating precision of
+      ``lhs`` and ``rhs``.
 
   Returns:
     An array containing the convolution result.
 
-  In the string case of `dimension_numbers`, each character identifies by
+  In the string case of ``dimension_numbers``, each character identifies by
   position:
 
-  - the batch dimensions in `lhs`, `rhs`, and the output with the character
+  - the batch dimensions in ``lhs``, ``rhs``, and the output with the character
     'N',
   - the feature dimensions in `lhs` and the output with the character 'C',
   - the input and output feature dimensions in rhs with the characters 'I'
     and 'O' respectively, and
   - spatial dimension correspondences between lhs, rhs, and the output using
     any distinct characters.
 
-  For example, to indicate dimension numbers consistent with the `conv` function
-  with two spatial dimensions, one could use `('NCHW', 'OIHW', 'NCHW')`. As
-  another example, to indicate dimension numbers consistent with the TensorFlow
-  Conv2D operation, one could use `('NHWC', 'HWIO', 'NHWC')`. When using the
-  latter form of convolution dimension specification, window strides are
-  associated with spatial dimension character labels according to the order in
-  which the labels appear in the `rhs_spec` string, so that `window_strides[0]`
-  is matched with the dimension corresponding to the first character
-  appearing in rhs_spec that is not `'I'` or `'O'`.
-
-  If `dimension_numbers` is `None`, the default is `('NCHW', 'OIHW', 'NCHW')`
-  (for a 2D convolution).
+  For example, to indicate dimension numbers consistent with the ``conv``
+  function with two spatial dimensions, one could use ``('NCHW', 'OIHW',
+  'NCHW')``. As another example, to indicate dimension numbers consistent with
+  the TensorFlow Conv2D operation, one could use ``('NHWC', 'HWIO', 'NHWC')``.
+  When using the latter form of convolution dimension specification, window
+  strides are associated with spatial dimension character labels according to
+  the order in which the labels appear in the ``rhs_spec`` string, so that
+  ``window_strides[0]`` is matched with the dimension corresponding to the first
+  character appearing in rhs_spec that is not ``'I'`` or ``'O'``.
+
+  If ``dimension_numbers`` is ``None``, the default is ``('NCHW', 'OIHW',
+  'NCHW')`` (for a 2D convolution).
   """
   dnums = conv_dimension_numbers(lhs.shape, rhs.shape, dimension_numbers)
   if lhs_dilation is None:

@@ -6394,16 +6405,31 @@ def remaining(original, *removed_lists):
 
 def _canonicalize_precision(precision):
   if precision is None:
-    return None
-  if isinstance(precision, Precision) or (
-      isinstance(precision, tuple)
-      and len(precision) == 2
-      and all(isinstance(p, Precision) for p in precision)
-  ):
+    if config.jax_default_matmul_precision is None:
+      return None
+    try:
+      return _precision_strings[config.jax_default_matmul_precision]
+    except KeyError:
+      raise ValueError(
+          "jax_default_matmul_precision flag must be set to None or a value in "
+          f"{_precision_strings}, but got {config.jax_default_matmul_precision}"
+      ) from None
+  elif isinstance(precision, str) and precision in _precision_strings:
+    return _precision_strings.get(precision)
+  elif isinstance(precision, Precision):
     return precision
+  elif (isinstance(precision, (list, tuple)) and len(precision) == 2 and
+        all(isinstance(p, Precision) for p in precision)):
+    return precision
+  elif (isinstance(precision, (list, tuple)) and len(precision) == 2 and
+        all(isinstance(s, str) for s in precision)):
+    s1, s2 = precision
+    return (_canonicalize_precision(s1), _canonicalize_precision(s2))
   else:
-    raise ValueError("Precision argument must be None, a lax.Precision value "
-                     f"or a tuple of two lax.Precision values; got {precision}")
+    raise ValueError(
+        f"Precision argument must be None, a string in {_precision_strings}, "
+        "a lax.Precision value or a tuple of two lax.Precision values or "
+        f"strings; got {precision}.")
 
 
 def conv_dimension_numbers(lhs_shape, rhs_shape, dimension_numbers
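
To make the widened PrecisionLike concrete, a small sketch (mine, not from the diff) of the argument forms _canonicalize_precision now accepts; lax.dot is used only as a convenient caller:

    import jax.numpy as jnp
    from jax import lax

    a = jnp.ones((8, 8), dtype=jnp.float32)
    b = jnp.ones((8, 8), dtype=jnp.float32)

    # These three requests resolve to the same Precision.HIGHEST value:
    lax.dot(a, b, precision=lax.Precision.HIGHEST)
    lax.dot(a, b, precision="highest")
    lax.dot(a, b, precision="float32")

    # A (lhs, rhs) pair of strings or enum values is also accepted:
    lax.dot(a, b, precision=("bfloat16", "float32"))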

jax/_src/numpy/lax_numpy.py

Lines changed: 4 additions & 13 deletions
@@ -28,7 +28,6 @@
 import collections
 import collections.abc
 import operator
-import os
 import types
 from typing import Any, Sequence, FrozenSet, Optional, Tuple, Union, cast
 from textwrap import dedent as _dedent

@@ -45,7 +44,7 @@
 from jax import dtypes
 from jax import errors
 from jax.core import UnshapedArray, ShapedArray, ConcreteArray, canonicalize_shape
-from jax.config import flags, config
+from jax.config import config
 from jax.interpreters.xla import DeviceArray, _DeviceArray, _CppDeviceArray
 from jax.interpreters.masking import Poly
 from jax import lax

@@ -55,14 +54,6 @@
   canonicalize_axis as _canonicalize_axis, maybe_named_axis)
 from jax.tree_util import tree_leaves, tree_flatten, tree_map
 
-FLAGS = flags.FLAGS
-flags.DEFINE_enum(
-    'jax_numpy_rank_promotion', os.getenv('JAX_NUMPY_RANK_PROMOTION', 'allow'),
-    enum_values=['allow', 'warn', 'raise'],
-    help=
-    'Control NumPy-style automatic rank promotion broadcasting '
-    '("allow", "warn", or "raise").')
-
 newaxis = None
 
 # Common docstring additions:

@@ -247,20 +238,20 @@ def _promote_shapes(fun_name, *args):
   if not nonscalar_ranks or len(set(nonscalar_ranks)) == 1:
     return args
   else:
-    if FLAGS.jax_numpy_rank_promotion != "allow":
+    if config.jax_numpy_rank_promotion != "allow":
       _rank_promotion_warning_or_error(fun_name, shapes)
     result_rank = len(lax.broadcast_shapes(*shapes))
     return [broadcast_to(arg, (1,) * (result_rank - len(shp)) + shp)
             for arg, shp in zip(args, shapes)]
 
 def _rank_promotion_warning_or_error(fun_name, shapes):
-  if FLAGS.jax_numpy_rank_promotion == "warn":
+  if config.jax_numpy_rank_promotion == "warn":
     msg = ("Following NumPy automatic rank promotion for {} on shapes {}. "
            "Set the jax_numpy_rank_promotion config option to 'allow' to "
            "disable this warning; for more information, see "
           "https://jax.readthedocs.io/en/latest/rank_promotion_warning.html.")
     warnings.warn(msg.format(fun_name, ' '.join(map(str, shapes))))
-  elif FLAGS.jax_numpy_rank_promotion == "raise":
+  elif config.jax_numpy_rank_promotion == "raise":
     msg = ("Operands could not be broadcast together for {} on shapes {} "
            "and with the config option jax_numpy_rank_promotion='raise'. "
            "For more information, see "

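For reference, a hedged sketch of the behavior that now keys off config.jax_numpy_rank_promotion; the shapes are made up, and the error type for the 'raise' mode is assumed to be ValueError:

    import jax
    import jax.numpy as jnp

    x = jnp.ones((3, 4))
    y = jnp.ones(4)  # lower rank than x, so broadcasting promotes it to (1, 4)

    with jax.numpy_rank_promotion("warn"):
      x + y  # works, but emits the rank-promotion warning defined above

    with jax.numpy_rank_promotion("raise"):
      try:
        x + y  # implicit rank promotion is now rejected
      except ValueError as e:
        print(e)
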
jax/config.py

Lines changed: 86 additions & 5 deletions
@@ -17,9 +17,9 @@
 import os
 import sys
 import threading
+from typing import List, Callable, Optional
 
 from jax import lib
-from typing import Callable, Optional
 
 def bool_env(varname: str, default: bool) -> bool:
   """Read an environment variable and interpret it as a boolean.

@@ -52,7 +52,7 @@ class Config:
   def __init__(self):
     self.values = {}
     self.meta = {}
-    self.FLAGS = NameSpace(self.read)
+    self.FLAGS = NameSpace(self.read, self.update)
     self.use_absl = False
     self._contextmanager_flags = set()
 
@@ -255,18 +255,70 @@ def set_state(new_val: bool):
     set_state.__doc__ = f"Context manager for `{name}` config option.\n\n{help}"
     return set_state
 
+  def define_enum_state(self, name: str, enum_values: List[str],
+                        default: Optional[str], help: str):
+    """Set up thread-local state and return a contextmanager for managing it.
+    Args:
+      name: string, converted to lowercase to define the name of the config
+        option (and absl flag). It is converted to uppercase to define the
+        corresponding shell environment variable.
+      enum_values: list of strings representing the possible values for the
+        option.
+      default: optional string, default value.
+      help: string, used to populate the flag help information as well as the
+        docstring of the returned context manager.
+    Returns:
+      A contextmanager to control the thread-local state value.
+      See docstring for ``define_bool_state``.
+    """
+    name = name.lower()
+    self.DEFINE_enum(name, os.getenv(name.upper(), default),
+                     enum_values=enum_values, help=help)
+    self._contextmanager_flags.add(name)
+
+    def get_state(self):
+      val = getattr(_thread_local_state, name, unset)
+      return val if val is not unset else self._read(name)
+    setattr(Config, name, property(get_state))
+
+    @contextlib.contextmanager
+    def set_state(new_val: Optional[str]):
+      if (new_val is not None and
+          (type(new_val) is not str or new_val not in enum_values)):
+        raise ValueError(f"new enum value must be None or in {enum_values}, "
+                         f"got {new_val} of type {type(new_val)}.")
+      prev_val = getattr(_thread_local_state, name, unset)
+      setattr(_thread_local_state, name, new_val)
+      try:
+        yield
+      finally:
+        if prev_val is unset:
+          delattr(_thread_local_state, name)
+        else:
+          setattr(_thread_local_state, name, prev_val)
+    set_state.__name__ = name[4:] if name.startswith('jax_') else name
+    set_state.__doc__ = f"Context manager for `{name}` config option.\n\n{help}"
+    return set_state
+
+
 _thread_local_state = threading.local()
 
 class Unset: pass
 unset = Unset()
 
-class NameSpace(object):
-  def __init__(self, getter):
-    self._getter = getter
+class NameSpace:
+  def __init__(self, getter, setter):
+    # must use super because we override this class's __setattr__, see
+    # https://docs.python.org/3/reference/datamodel.html#object.__setattr__
+    super().__setattr__('_getter', getter)
+    super().__setattr__('_setter', setter)
 
   def __getattr__(self, name):
     return self._getter(name)
 
+  def __setattr__(self, name, val):
+    self._setter(name, val)
+
 
 config = Config()
 flags = config

@@ -357,3 +409,32 @@ def _update_x64_thread_local(val):
   config._contextmanager_flags.remove("jax_enable_x64")
 
 Config.x64_enabled = Config.jax_enable_x64  # type: ignore
+
+
+numpy_rank_promotion = config.define_enum_state(
+    name='jax_numpy_rank_promotion',
+    enum_values=['allow', 'warn', 'raise'],
+    default='allow',
+    help=('Control NumPy-style automatic rank promotion broadcasting '
+          '("allow", "warn", or "raise").'))
+
+default_matmul_precision = config.define_enum_state(
+    name='jax_default_matmul_precision',
+    enum_values=['bfloat16', 'tensorfloat32', 'float32'],
+    default=None,
+    help=('Control the default matmul and conv precision for 32bit inputs.\n\n'
+
+          'Some platforms, like TPU, offer configurable precision levels for '
+          'matrix multiplication and convolution computations, trading off '
+          'accuracy for speed. The precision can be controlled for each '
+          'operation; for example, see the :func:`jax.lax.conv_general_dilated` '
+          'and :func:`jax.lax.dot` docstrings. But it can be useful to control '
+          'the default behavior obtained when an operation is not given a '
+          'specific precision.\n\n'
+
+          'This option can be used to control the default precision '
+          'level for computations involved in matrix multiplication and '
+          'convolution on 32bit inputs. The levels roughly describe the '
+          "precision at which scalar products are computed. The 'bfloat16' "
+          "option is the fastest and least precise; 'float32' is similar to "
+          "full float32 precision; 'tensorfloat32' is intermediate.\n\n"))
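
To show how an option created by define_enum_state is meant to be driven in practice, a sketch under the assumption that the environment variable (derived from name.upper() above) is set before jax.config first reads it:

    import os
    # Read via os.getenv(name.upper(), default) when the option is defined,
    # so set it before importing jax.
    os.environ["JAX_DEFAULT_MATMUL_PRECISION"] = "tensorfloat32"

    import jax
    from jax.config import config

    config.update("jax_default_matmul_precision", "bfloat16")  # global runtime update
    config.FLAGS.jax_default_matmul_precision = "float32"      # writable FLAGS namespace (new __setattr__)

    with jax.default_matmul_precision("tensorfloat32"):        # thread-local, scoped override
      pass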

tests/api_test.py

Lines changed: 53 additions & 0 deletions
@@ -25,6 +25,7 @@
 import weakref
 import functools
 import itertools as it
+import operator as op
 
 from absl import logging
 from absl.testing import absltest, parameterized

@@ -2399,6 +2400,58 @@ def test_large_python_int_to_float(self):
     out = lax.convert_element_type(2 ** 100, jnp.float32)  # doesn't crash
     self.assertArraysEqual(out, np.float32(2 ** 100))
 
+  def test_dot_precision_context_manager(self):
+    x = jnp.zeros((2, 2))
+
+    with jax.default_matmul_precision(None):
+      jnp.dot(x, x)  # doesn't crash
+      jaxpr = jax.make_jaxpr(jnp.dot)(x, x)
+    self.assertIn('precision=None', str(jaxpr))
+
+    with jax.default_matmul_precision("bfloat16"):
+      x @ x  # doesn't crash
+      jaxpr = jax.make_jaxpr(op.matmul)(x, x)
+    self.assertIn('precision=DEFAULT', str(jaxpr))
+
+    with jax.default_matmul_precision("tensorfloat32"):
+      jnp.dot(x, x)  # doesn't crash
+      jaxpr = jax.make_jaxpr(jnp.dot)(x, x)
+    self.assertIn('precision=HIGH\n', str(jaxpr))
+
+    with jax.default_matmul_precision("float32"):
+      jnp.dot(x, x)  # doesn't crash
+      jaxpr = jax.make_jaxpr(jnp.dot)(x, x)
+    self.assertIn('precision=HIGHEST', str(jaxpr))
+
+    dot = partial(jnp.dot, precision=lax.Precision.HIGHEST)
+    with jax.default_matmul_precision("tensorfloat32"):
+      dot(x, x)  # doesn't crash
+      jaxpr = jax.make_jaxpr(dot)(x, x)
+    self.assertIn('precision=HIGHEST', str(jaxpr))
+
+  def test_dot_precision_flag(self):
+    x = jnp.zeros((2, 2))
+
+    prev_val = config._read("jax_default_matmul_precision")
+    try:
+      config.FLAGS.jax_default_matmul_precision = "tensorfloat32"
+      jnp.dot(x, x)  # doesn't crash
+      jaxpr = jax.make_jaxpr(jnp.dot)(x, x)
+    finally:
+      config.FLAGS.jax_default_matmul_precision = prev_val
+    self.assertIn('precision=HIGH', str(jaxpr))
+    self.assertEqual(prev_val, config._read("jax_default_matmul_precision"))
+
+    prev_val = config._read("jax_default_matmul_precision")
+    try:
+      config.update('jax_default_matmul_precision', 'tensorfloat32')
+      jnp.dot(x, x)  # doesn't crash
+      jaxpr = jax.make_jaxpr(jnp.dot)(x, x)
+    finally:
+      config.update('jax_default_matmul_precision', prev_val)
+    self.assertIn('precision=HIGH', str(jaxpr))
+    self.assertEqual(prev_val, config._read("jax_default_matmul_precision"))
+
 
 class RematTest(jtu.JaxTestCase):
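
Outside the test harness, the same check these tests rely on can be reproduced by printing a jaxpr and looking for the precision attribute on the dot_general equation (a sketch; the exact jaxpr rendering is what the assertions above assume):

    import jax
    import jax.numpy as jnp

    x = jnp.zeros((2, 2))
    with jax.default_matmul_precision("tensorfloat32"):
      print(jax.make_jaxpr(jnp.dot)(x, x))  # expect precision=HIGH on the dot_general eqn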
