@@ -506,8 +506,17 @@ def concatenate(operands: Sequence[Array], dimension: int) -> Array:
506
506
Precision = xla_client .PrecisionConfig .Precision
507
507
Precision .__str__ = lambda precision : precision .name
508
508
PrecisionType = Any
509
- PrecisionLike = Union [None , PrecisionType , Tuple [PrecisionType , PrecisionType ]]
510
-
509
+ PrecisionLike = Union [None , str , PrecisionType , Tuple [str , str ],
510
+ Tuple [PrecisionType , PrecisionType ]]
511
+ _precision_strings = {
512
+ 'highest' : Precision .HIGHEST ,
513
+ 'float32' : Precision .HIGHEST ,
514
+ 'bfloat16_3x' : Precision .HIGH ,
515
+ 'tensorfloat32' : Precision .HIGH ,
516
+ 'bfloat16' : Precision .DEFAULT ,
517
+ 'fastest' : Precision .DEFAULT ,
518
+ None : Precision .DEFAULT ,
519
+ }
511
520
512
521
class ConvDimensionNumbers (NamedTuple ):
513
522
"""Describes batch, spatial, and feature dimensions of a convolution.
@@ -555,42 +564,44 @@ def conv_general_dilated(
555
564
rhs_dilation: `None`, or a sequence of `n` integers, giving the
556
565
dilation factor to apply in each spatial dimension of `rhs`. RHS dilation
557
566
is also known as atrous convolution.
558
- dimension_numbers: either `None`, a `ConvDimensionNumbers` object, or
559
- a 3-tuple `(lhs_spec, rhs_spec, out_spec)`, where each element is a string
560
- of length `n+2`.
567
+ dimension_numbers: either `None`, a `` ConvDimensionNumbers` ` object, or
568
+ a 3-tuple `` (lhs_spec, rhs_spec, out_spec)`` , where each element is a
569
+ string of length `n+2`.
561
570
feature_group_count: integer, default 1. See XLA HLO docs.
562
571
batch_group_count: integer, default 1. See XLA HLO docs.
563
572
precision: Optional. Either ``None``, which means the default precision for
564
573
the backend, a ``lax.Precision`` enum value (``Precision.DEFAULT``,
565
- ``Precision.HIGH`` or ``Precision.HIGHEST``) or a tuple of two
566
- ``lax.Precision`` enums indicating precision of ``lhs``` and ``rhs``.
574
+ ``Precision.HIGH`` or ``Precision.HIGHEST``), a string (e.g. 'highest' or
575
+ 'fastest', see the ``jax.default_matmul_precision`` context manager), or a
576
+ tuple of two ``lax.Precision`` enums or strings indicating precision of
577
+ ``lhs`` and ``rhs``.
567
578
568
579
Returns:
569
580
An array containing the convolution result.
570
581
571
- In the string case of `dimension_numbers`, each character identifies by
582
+ In the string case of `` dimension_numbers` `, each character identifies by
572
583
position:
573
584
574
- - the batch dimensions in `lhs`, `rhs`, and the output with the character
585
+ - the batch dimensions in `` lhs`` , `` rhs` `, and the output with the character
575
586
'N',
576
587
- the feature dimensions in `lhs` and the output with the character 'C',
577
588
- the input and output feature dimensions in rhs with the characters 'I'
578
589
and 'O' respectively, and
579
590
- spatial dimension correspondences between lhs, rhs, and the output using
580
591
any distinct characters.
581
592
582
- For example, to indicate dimension numbers consistent with the `conv` function
583
- with two spatial dimensions, one could use `('NCHW', 'OIHW', 'NCHW')`. As
584
- another example, to indicate dimension numbers consistent with the TensorFlow
585
- Conv2D operation, one could use `('NHWC', 'HWIO', 'NHWC')`. When using the
586
- latter form of convolution dimension specification, window strides are
587
- associated with spatial dimension character labels according to the order in
588
- which the labels appear in the `rhs_spec` string, so that `window_strides[0]`
589
- is matched with the dimension corresponding to the first character
590
- appearing in rhs_spec that is not `'I'` or `'O'`.
591
-
592
- If `dimension_numbers` is `None`, the default is `('NCHW', 'OIHW', 'NCHW')`
593
- (for a 2D convolution).
593
+ For example, to indicate dimension numbers consistent with the `` conv``
594
+ function with two spatial dimensions, one could use `` ('NCHW', 'OIHW',
595
+ 'NCHW')``. As another example, to indicate dimension numbers consistent with
596
+ the TensorFlow Conv2D operation, one could use `` ('NHWC', 'HWIO', 'NHWC')``.
597
+ When using the latter form of convolution dimension specification, window
598
+ strides are associated with spatial dimension character labels according to
599
+ the order in which the labels appear in the `` rhs_spec`` string, so that
600
+ ``window_strides[0]`` is matched with the dimension corresponding to the first
601
+ character appearing in rhs_spec that is not `` 'I'`` or `` 'O'` `.
602
+
603
+ If `` dimension_numbers`` is `` None`` , the default is `` ('NCHW', 'OIHW',
604
+ 'NCHW')`` (for a 2D convolution).
594
605
"""
595
606
dnums = conv_dimension_numbers (lhs .shape , rhs .shape , dimension_numbers )
596
607
if lhs_dilation is None :
@@ -6394,16 +6405,31 @@ def remaining(original, *removed_lists):
6394
6405
6395
6406
def _canonicalize_precision (precision ):
6396
6407
if precision is None :
6397
- return None
6398
- if isinstance (precision , Precision ) or (
6399
- isinstance (precision , tuple )
6400
- and len (precision ) == 2
6401
- and all (isinstance (p , Precision ) for p in precision )
6402
- ):
6408
+ if config .jax_default_matmul_precision is None :
6409
+ return None
6410
+ try :
6411
+ return _precision_strings [config .jax_default_matmul_precision ]
6412
+ except KeyError :
6413
+ raise ValueError (
6414
+ "jax_default_matmul_precision flag must be set to None or a value in "
6415
+ f"{ _precision_strings } , but got { config .jax_default_matmul_precision } "
6416
+ ) from None
6417
+ elif isinstance (precision , str ) and precision in _precision_strings :
6418
+ return _precision_strings .get (precision )
6419
+ elif isinstance (precision , Precision ):
6403
6420
return precision
6421
+ elif (isinstance (precision , (list , tuple )) and len (precision ) == 2 and
6422
+ all (isinstance (p , Precision ) for p in precision )):
6423
+ return precision
6424
+ elif (isinstance (precision , (list , tuple )) and len (precision ) == 2 and
6425
+ all (isinstance (s , str ) for s in precision )):
6426
+ s1 , s2 = precision
6427
+ return (_canonicalize_precision (s1 ), _canonicalize_precision (s2 ))
6404
6428
else :
6405
- raise ValueError ("Precision argument must be None, a lax.Precision value "
6406
- f"or a tuple of two lax.Precision values; got { precision } " )
6429
+ raise ValueError (
6430
+ f"Precision argument must be None, a string in { _precision_strings } , "
6431
+ "a lax.Precision value or a tuple of two lax.Precision values or "
6432
+ f"strings; got { precision } ." )
6407
6433
6408
6434
6409
6435
def conv_dimension_numbers (lhs_shape , rhs_shape , dimension_numbers
0 commit comments