
Why can JAX run fp8 on Nvidia GPUs with sm < 89? #23124

Answered by Angelogeb
woct0rdho asked this question in Q&A

XLA falls back to a higher precision.
It upcasts the operands to fp16 (if supported), performs the dot in that precision, and then downcasts the result back to fp8.
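To make the upcast, dot, downcast sequence concrete, here is a minimal hand-written sketch of it in JAX. This is not XLA's actual implementation; the helper name `fp8_dot_via_fp16` is hypothetical, and the sketch assumes the `jnp.float8_e4m3fn` format.

```python
import jax
import jax.numpy as jnp

# Hypothetical helper that mimics the fallback by hand (not an XLA API).
def fp8_dot_via_fp16(a, b):
    """Upcast fp8 operands to fp16, do the dot in fp16, downcast to fp8."""
    a16 = a.astype(jnp.float16)
    b16 = b.astype(jnp.float16)
    out16 = jnp.dot(a16, b16)      # the actual math runs in fp16 (or wider)
    return out16.astype(a.dtype)   # round to fp8 once, at the very end

# Small integers are exactly representable in e4m3fn, so nothing is lost here.
a = jnp.array([[1.0, 2.0], [3.0, 4.0]], dtype=jnp.float8_e4m3fn)
b = jnp.ones((2, 2), dtype=jnp.float8_e4m3fn)

out = jax.jit(fp8_dot_via_fp16)(a, b)
print(out.dtype)  # float8_e4m3fn
print(out.astype(jnp.float32))
```

Because the products and partial sums are kept at fp16 (or wider) and only the final result is rounded to fp8, the numbers can differ from what native fp8 hardware would produce on less friendly inputs.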

I wouldn't call this fp8 emulation, since it does not try to match the corresponding fp8 semantics; it's more of a "higher precision" fallback.

As far as I can tell, this happens through a mix of Gemm Rewriter passes and Float Normalization passes.

You can add --xla_dump_hlo_pass_re=.* to XLA_FLAGS (together with --xla_dump_to=<directory> to choose where the dumps are written) to see how the IR changes through the compiler passes.
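A minimal sketch of that workflow, assuming a temporary directory as the dump location. The flags have to be in `XLA_FLAGS` before JAX initializes its backend. On a machine without a GPU this dumps the CPU pipeline instead, so the GEMM rewriter passes mentioned above would only show up when compiling for an NVIDIA GPU.

```python
import os
import tempfile

# Set XLA_FLAGS before importing JAX; flags added after backend init are ignored.
dump_dir = tempfile.mkdtemp(prefix="xla_dump_")
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "")
    + f" --xla_dump_to={dump_dir} --xla_dump_hlo_pass_re=.*"
)

import jax
import jax.numpy as jnp

a = jnp.ones((16, 16), dtype=jnp.float8_e4m3fn)
b = jnp.ones((16, 16), dtype=jnp.float8_e4m3fn)

# Compiling this fp8 dot makes XLA dump HLO before and after each matching pass.
out = jax.jit(jnp.dot)(a, b)
out.block_until_ready()

dumped = sorted(os.listdir(dump_dir))
print(len(dumped), "files, e.g.", dumped[:3])
```

The dumped file names should include the pass names, which makes it easier to find the steps where the fp8 operands get converted.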

Some backstory can be found in this discussion and git logs.

Answer selected by woct0rdho