
Improved docs for (default) matmul precision #10413

Open
@sanchit-gandhi

Description

  1. It is somewhat unintuitive that the default matmul precision on TPU is bfloat16, especially for users coming from PyTorch/GPU, where the default precision is float32. Information about the default matrix multiplication precision on TPUs is extremely difficult to find. There is a short section in the README.md of the Cloud TPU Colab folder of the JAX repo: https://github.com/google/jax/tree/main/cloud_tpu_colabs#bfloat16-dtype However, it is somewhat unclear: it references 'MXUs' without explaining what the abbreviation means, and it only shows how to change the default precision manually on an op-by-op basis by setting precision=jax.lax.Precision.XXX. This gives the impression that, to change the TPU precision to float32, one must pass the keyword argument precision=jax.lax.Precision.HIGHEST to every jax.numpy operation in one's script.

  2. It is difficult to find out how the default precision can be changed. Performing matmuls in the default bfloat16 precision can lead to undesirable results. At Hugging Face, we constantly run into problems with the fast, low-precision TPU default; see for example: Diverging PT-Flax Wav2Vec2 Hidden-States huggingface/transformers#15754
    As for changing the default matmul precision, the docs do mention the default matmul precision context manager: https://jax.readthedocs.io/en/latest/_autosummary/jax.default_matmul_precision.html However, they do not explicitly state how to use this context manager to change the default matmul precision (for instance, with an example). It's hard to tell from the docs that you have to write your code under the context manager as follows:

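```python
import jax
import jax.numpy as jnp

with jax.default_matmul_precision("float32"):  # or "bfloat16" for the fastest, lowest precision
    y = jnp.ones((4, 4)) @ jnp.ones((4, 4))  # e.g. a model's forward pass
```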
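To make the contrast with point 1 concrete, here is a minimal sketch of the two styles side by side: a per-operation `precision=` argument, which has to be repeated at every call site, and the context manager, which covers a whole block. The array shapes and PRNG seed are arbitrary choices for illustration. On CPU all three results coincide, because CPU matmuls already run in full float32; the gap between `y_default` and the other two is what one would expect to see on TPU.

```python
import jax
import jax.numpy as jnp

key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
a = jax.random.normal(key_a, (256, 256), dtype=jnp.float32)
b = jax.random.normal(key_b, (256, 256), dtype=jnp.float32)

# Backend default: a single bfloat16 pass on TPU, full float32 on CPU.
y_default = jnp.dot(a, b)

# Per-operation override: every call site needs its own precision argument.
y_highest = jnp.dot(a, b, precision=jax.lax.Precision.HIGHEST)

# The same override applied to a whole block of code via the context manager.
with jax.default_matmul_precision("float32"):
    y_scoped = jnp.dot(a, b)

# Largest elementwise difference between the default and full-precision results.
print(jnp.max(jnp.abs(y_default - y_highest)))
```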

The docs also gloss over three additional ways to change the default matmul precision, highlighted brilliantly in this PR: #6143 (comment). These three methods require little or no change to one's actual script, just a shell environment variable, a command-line flag or a one-line JAX config change, and are arguably much easier to use and less obtrusive.
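The three script-level alternatives can be sketched as follows. The environment variable and the command-line flag are shown as comments because they live outside the Python file; the flag form assumes the script parses its command line with absl, and `train.py` is a hypothetical script name. The `jax.config.update` call is the in-code equivalent and should run before any matmuls are traced.

```python
# Option A, environment variable (must be set before JAX is imported):
#   JAX_DEFAULT_MATMUL_PRECISION=float32 python train.py
# Option B, command-line flag (assumes the script parses its flags with absl):
#   python train.py --jax_default_matmul_precision=float32
import jax
import jax.numpy as jnp

# Option C, config change: one line near the top of the script. It applies to
# matmuls and convolutions that do not set precision= explicitly.
jax.config.update("jax_default_matmul_precision", "float32")

x = jnp.arange(16.0, dtype=jnp.float32).reshape(4, 4)
y = x @ x.T  # runs at float32 precision without a per-op precision= argument
```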

It would be great if the default matmul precisions for CPU/GPU/TPU were documented, along with what bfloat16, tensorfloat32 and float32 precision actually mean for matmuls in terms of the number of passes. It would also be super helpful if all four methods for changing the default precision were added to the docs with short usage examples, as done in the aforementioned PR.
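As an illustration of what "number of passes" means, the sketch below emulates two precision levels with ordinary float32 arithmetic: a single bfloat16 pass (inputs rounded to bfloat16, products accumulated in float32) and a 3-pass scheme that splits each float32 input into a bfloat16 "high" part plus a bfloat16 residual. The `jax.lax.Precision` docs describe HIGH as three bfloat16 passes; the split used here is the textbook way to build that, shown for illustration only and not as XLA's actual TPU kernel. The helper names (`to_bf16`, `matmul_1pass`, `matmul_3pass`, `rel_err`) are hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Force full precision inside the emulation so it behaves the same on any backend.
HIGHEST = jax.lax.Precision.HIGHEST


def to_bf16(x):
    """Round float32 values to bfloat16, then widen back to float32 (exact)."""
    return x.astype(jnp.bfloat16).astype(jnp.float32)


def matmul_1pass(a, b):
    # One pass: inputs rounded to bfloat16, products accumulated in float32.
    return jnp.matmul(to_bf16(a), to_bf16(b), precision=HIGHEST)


def matmul_3pass(a, b):
    # Split each input into a bfloat16 high part and a bfloat16 residual, then
    # keep three of the four cross terms (the tiny lo*lo term is dropped).
    a_hi, b_hi = to_bf16(a), to_bf16(b)
    a_lo, b_lo = to_bf16(a - a_hi), to_bf16(b - b_hi)
    return (jnp.matmul(a_hi, b_hi, precision=HIGHEST)
            + jnp.matmul(a_hi, b_lo, precision=HIGHEST)
            + jnp.matmul(a_lo, b_hi, precision=HIGHEST))


key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
a = jax.random.normal(key_a, (128, 128), dtype=jnp.float32)
b = jax.random.normal(key_b, (128, 128), dtype=jnp.float32)
# Float64 reference computed with NumPy.
reference = np.asarray(a, dtype=np.float64) @ np.asarray(b, dtype=np.float64)


def rel_err(out):
    """Max absolute error, normalized by the largest reference entry."""
    diff = np.abs(np.asarray(out, dtype=np.float64) - reference)
    return float(np.max(diff) / np.max(np.abs(reference)))


err_1pass = rel_err(matmul_1pass(a, b))
err_3pass = rel_err(matmul_3pass(a, b))
print(f"1 pass: {err_1pass:.1e}   3 passes: {err_3pass:.1e}")
```

The 3-pass error should come out well below the 1-pass error, which is the kind of gap that shows up as diverging hidden states when comparing a TPU run against a float32 reference.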
