I'm also asking this in the JAX repo and a few Discord channels but haven't gotten an answer yet. fp8 has hardware support only on GPUs with sm >= 89, such as the RTX 4090 (Ada) or H100 (Hopper). I've seen people trying to run it in PyTorch (e.g., this script) on older GPUs and getting errors, but JAX can actually run it on older GPUs. I tried:

```python
import jax
import jax.numpy as jnp

def f(x, y):
    return x @ y

a = jnp.ones((3, 3), dtype=jnp.float8_e4m3fn)
print(jax.jit(f).lower(a, a).as_text())
```

and I can see the dtype is f8E4M3FN in the HLO IR. Then I compared the dumped unoptimized LLVM IR for the two dtypes:

- `module_0005.jit_f.ir-no-opt.ll` with `dtype=jnp.float32`
- `module_0005.jit_f.ir-no-opt.ll` with `dtype=jnp.float8_e4m3fn`

What is XLA actually doing here on GPUs without native fp8 support?
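Dump files with names like these typically come from XLA's `--xla_dump_to` flag; a minimal sketch to reproduce them (the dump directory is just an example path):

```python
import os
# Assumption: XLA_FLAGS must be set before JAX initializes its XLA backend.
os.environ["XLA_FLAGS"] = "--xla_dump_to=/tmp/xla_dump"

import jax
import jax.numpy as jnp

def f(x, y):
    return x @ y

a = jnp.ones((3, 3), dtype=jnp.float8_e4m3fn)
jax.jit(f)(a, a)
# /tmp/xla_dump should now contain HLO and LLVM IR dumps for jit_f,
# including *-no-opt.ll files like the ones referenced above.
```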
-
XLA falls back to a higher precision: it upcasts the operands to fp16 (if supported), performs the dot in that precision, and then downcasts the result back to fp8. I wouldn't call this fp8 emulation, since it does not try to match the corresponding fp8 semantics; it's more of a "higher precision" fallback. As far as I can tell, this happens through a mix of Gemm Rewriter and Float Normalization passes.
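Roughly, the rewritten computation behaves like the following sketch (the actual transformation happens inside XLA's HLO passes; the fp16 widening here is an assumption based on the fallback described above and is hardware-dependent):

```python
import jax
import jax.numpy as jnp

def fp8_dot_fallback(x, y):
    # Sketch of the fallback: widen the fp8 operands, do the dot in a
    # higher precision, then narrow the result back to the fp8 dtype.
    wide = jnp.float16  # assumed widened dtype; depends on GPU support
    out = jnp.dot(x.astype(wide), y.astype(wide))
    return out.astype(x.dtype)

a = jnp.ones((3, 3), dtype=jnp.float8_e4m3fn)
print(jax.jit(fp8_dot_fallback)(a, a))
```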
You can add `--xla_dump_hlo_pass_re=.*` to `XLA_FLAGS` to see how the IR changes through the compiler passes. Some backstory can be found in this discussion and git logs.
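For example (the flag values and the dump directory are illustrative):

```python
import os
# Assumption: XLA_FLAGS must be set before JAX initializes the XLA backend.
os.environ["XLA_FLAGS"] = (
    "--xla_dump_to=/tmp/xla_dump_passes "
    "--xla_dump_hlo_pass_re=.*"
)

import jax
import jax.numpy as jnp

a = jnp.ones((3, 3), dtype=jnp.float8_e4m3fn)
jax.jit(lambda x, y: x @ y)(a, a)
# The dump directory now contains one HLO snapshot per pass, which makes it
# easy to spot where the fp8 dot is rewritten into a higher-precision one.
```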