Claim: FP8 matmuls on H100s are less precise than on MI300x.
## Really?
Not really. ROCm kernels may be dumber than their CUDA counterparts. However,
- Most fp8 users on AMD have slightly more unique finite numbers per byte than on Hopper
- fp8 matmuls will accumulate outputs more precisely on AMD than on Hopper.
Insofar as a single tensor/matrix core op is concerned, the initial statement is true.
## Cardinality
Please see here and here and especially here for reference.
dtype | brand | unique finites (±0 counted once) | outliers | smallest $2^p$ | largest $2^p$ |
---|---|---|---|---|---|
e5m2 | nvidia | 247 | ±0, ±inf, ±nan x3 | $2^{-16}$ | $2^{15}$ |
e4m3fn | nvidia | 253 | ±0, ±nan | $2^{-9}$ | $2^8$ |
e4m3fnuz | amd | 255 | nan | $2^{-10}$ | $2^7$ |
e5m2fnuz | amd | 255 | nan | $2^{-17}$ | $2^{15}$ |
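As a cross-check on the table, here is a minimal C++17 sketch that decodes all 256 bit patterns of each format and recomputes the unique finite count (with ±0 counted once), the smallest and largest representable powers of two, and the largest finite value. `Fp8Format`, `decode` and `stats` are hypothetical names for illustration; the special-value rules it encodes are the usual ones (e5m2 reserves the all-ones exponent for inf/NaN, e4m3fn reserves only `S.1111.111` for NaN, and the fnuz formats reserve only `0x80`).

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <limits>
#include <set>

// Hypothetical description of an 8-bit float format.
struct Fp8Format {
    const char* name;
    int exp_bits;
    int man_bits;
    int bias;
    enum Kind { IEEE_LIKE, FN, FNUZ } kind;  // how special values are encoded
};

const Fp8Format E5M2{"e5m2", 5, 2, 15, Fp8Format::IEEE_LIKE};
const Fp8Format E4M3FN{"e4m3fn", 4, 3, 7, Fp8Format::FN};
const Fp8Format E4M3FNUZ{"e4m3fnuz", 4, 3, 8, Fp8Format::FNUZ};
const Fp8Format E5M2FNUZ{"e5m2fnuz", 5, 2, 16, Fp8Format::FNUZ};

// Decode one byte; NaN encodings return NaN, infinities return +/-inf.
double decode(const Fp8Format& f, uint8_t byte) {
    assert(f.exp_bits + f.man_bits == 7);
    const int sign = byte >> 7;
    const int exp = (byte >> f.man_bits) & ((1 << f.exp_bits) - 1);
    const int man = byte & ((1 << f.man_bits) - 1);
    const int exp_max = (1 << f.exp_bits) - 1;
    const int man_max = (1 << f.man_bits) - 1;
    const double nan = std::numeric_limits<double>::quiet_NaN();
    const double inf = std::numeric_limits<double>::infinity();

    switch (f.kind) {
    case Fp8Format::IEEE_LIKE:  // all-ones exponent: inf (mantissa 0) or NaN
        if (exp == exp_max) return man == 0 ? (sign ? -inf : inf) : nan;
        break;
    case Fp8Format::FN:         // no inf; only S.1111.111 is NaN
        if (exp == exp_max && man == man_max) return nan;
        break;
    case Fp8Format::FNUZ:       // no inf, no -0; 0x80 is the only NaN
        if (byte == 0x80) return nan;
        break;
    }
    const double frac = double(man) / (1 << f.man_bits);
    const double mag = (exp == 0)
        ? std::ldexp(frac, 1 - f.bias)            // subnormal
        : std::ldexp(1.0 + frac, exp - f.bias);   // normal
    return sign ? -mag : mag;
}

struct Stats {
    int unique_finite;  // distinct finite values, +0 and -0 counted once
    int min_p;          // smallest positive value is 2^min_p
    int max_p;          // largest power of two <= max finite value
    double max_finite;
};

Stats stats(const Fp8Format& f) {
    std::set<double> finite;  // std::set treats -0.0 and +0.0 as equivalent
    for (int b = 0; b < 256; ++b) {
        const double v = decode(f, uint8_t(b));
        if (std::isfinite(v)) finite.insert(v);
    }
    const double min_pos = *finite.upper_bound(0.0);  // smallest subnormal
    const double max_finite = *finite.rbegin();
    return {int(finite.size()), std::ilogb(min_pos), std::ilogb(max_finite), max_finite};
}
```

For example, `stats(E4M3FNUZ)` gives 255 values spanning $2^{-10}$ to $2^7$ with a max finite value of 240, while `stats(E4M3FN)` tops out at 448.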
In principle, AMD supports both pure IEEE && fnuz variants, but in practice
- fnuz is preferred by end-users, e.g. torchao hardcodes mi300x -> fnuz
- the ROCm ecosystem has poorer support for the pure e4m3/e5m2 variants.
For Hopper, nvidia follows the spec defined in FP8 Formats for Deep Learning.
So, “in general”, AMD fp8 formats are marginally more expressive.
### Are you sure?
No. I lack an intuition pump for what bits “should matter”. I think picking 1.875 over 1.75 as the max normal’s mantissa (e4m3fnuz tops out at $1.875 \times 2^7 = 240$, e4m3fn at $1.75 \times 2^8 = 448$) is an unlikely selling point.
Some people have claimed marginally better accuracy on real-world workloads.
## Accumulation
### DeepSeek
Remember DeepSeek-V3? They complained:

> we observe that the accumulation precision of FP8 GEMM on NVIDIA H800 GPUs is limited to retaining around 14 bits, which is significantly lower than FP32 accumulation precision.
SageAttention2 explains this more accurately:
> After narrowing down the problem, we find that the accumulator for the `mma(f32f8f8f32)` instruction on the Ada and Hopper architecture is actually FP22, specifically with 1 sign bit, 8 exponent bits, and 13 mantissa bits. Specifically, for `mma(f32f8f8f32)` instruction $C = AB + D$, where $A$, $B$ are FP8 matrices and $C, D$ are FP32 matrices, we initialize the $A, B$ to zero and vary $D$ to test the data type of the accumulator […] when $D$ is initialized with more than 13 mantissa bits, the error of $C$ corresponds to the difference between the results of `mma(f32f16f16f32)` and `mma(f32f8f8f32)`.
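It doesn’t matter what the inputs $A, B$ even are. So long as their dtype is fp8, the 10 least-significant mantissa bits of the output $C$ will always be zero: only 13 of FP32’s 23 explicit mantissa bits survive, i.e. 14 significant bits once the implicit leading one is counted, which lines up with DeepSeek’s “around 14 bits”.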
This is easy to test, given you have basic knowledge of floating point formats && CUDA. You essentially need to execute something like,
```cpp
#include <bit>      // std::bit_cast (C++20)
#include <cstdint>
#include <mma.h>    // nvcuda::wmma
using namespace nvcuda;

using e4m3 = //...;
// sign 0, exponent 127, all 23 mantissa bits set: 2 - 2^-23
constexpr uint32_t bits = 0b0'01111111'11111111111111111111111; // 2 - epsilon

// D = A*B + C with A = B = 0, so any difference between C and D comes from
// the accumulator itself. M, N, K (the tile shape) are left unspecified here.
__global__ void fp8_mma_kernel(
    const e4m3* A, const e4m3* B, float* D
) {
    wmma::fragment<wmma::matrix_a, M, N, K, e4m3, wmma::row_major> a_frag;
    wmma::fragment<wmma::matrix_b, M, N, K, e4m3, wmma::col_major> b_frag;
    wmma::fragment<wmma::accumulator, M, N, K, float> c_frag;
    wmma::fragment<wmma::accumulator, M, N, K, float> d_frag;
    wmma::fill_fragment(a_frag, .0f);
    wmma::fill_fragment(b_frag, .0f);
    wmma::fill_fragment(c_frag, std::bit_cast<float>(bits));
    wmma::mma_sync(d_frag, a_frag, b_frag, c_frag);
    wmma::store_matrix_sync(D, d_frag, N, wmma::mem_row_major);
}
```

(except with some inline `mma`, because `wmma` does not support fp8)
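Under the FP22 reading, the kernel above would store $2 - 2^{-13}$ instead of $2 - 2^{-23}$. Below is a minimal host-side C++17 sketch of that model; `to_bits`, `from_bits` and `truncate_mantissa` are hypothetical helpers, and modelling the loss as truncation (round toward zero) rather than some other rounding mode is an assumption of this sketch.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// Portable float <-> bit-pattern conversion (memcpy avoids aliasing UB).
uint32_t to_bits(float f) {
    uint32_t u;
    std::memcpy(&u, &f, sizeof u);
    return u;
}

float from_bits(uint32_t u) {
    float f;
    std::memcpy(&f, &u, sizeof f);
    return f;
}

// Hypothetical accumulator model: keep the sign, the 8-bit exponent and only
// the top `kept` of FP32's 23 mantissa bits, clearing the rest (round toward
// zero). kept = 13 gives the FP22 (1 + 8 + 13) layout. Meant for finite inputs.
float truncate_mantissa(float x, int kept) {
    assert(kept >= 0 && kept <= 23);
    const uint32_t drop_mask = (uint32_t(1) << (23 - kept)) - 1;  // low bits to clear
    return from_bits(to_bits(x) & ~drop_mask);
}
```

Passing `kept = 23` leaves the value untouched, which corresponds to a full FP32 accumulator.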
So, SageAttention2 shows that the result of $D = AB + C$, when fed zero $A, B$ and arbitrary $C$, has FP22 precision. But does that mean:
- the mantissa of $C$ is dropped immediately when read, before any FMA?
- the lowest bits are preserved for contributions from $AB$?
- the bits don’t exist for the accumulator $D$ to begin with?
- the higher bits that should exist from carried/combined $AB+C$ are also lost?
### Slightly different test
To test this, I expanded on the SageAttention2 strategy a little bit. Because their first suggested test did not involve $A, B$ calculations, I wanted to know if a single in-place FMA (using a matrix/tensor core) would start dropping mantissa bits.
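That is, for each dot product

$$D[i,j] := A[i,0]\,B[0,j] + \dots + A[i,31]\,B[31,j] + C[i,j]$$

I simply set all `A[:,:31]` and `B[:31,:]` to 0, and $D = C$, which simplifies the above to `C[i,j] += A[i,31]*B[31,j]`, or, for scalars $a = A[i,31]$, $b = B[31,j]$, $c = C[i,j]$:

$$c \leftarrow a \cdot b + c$$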
To test when mantissa bits start dropping, I ensure that
- $a \cdot b = 2^m \cdot 2^n = 2^p$; see the table above for the range of $m$ and $n$ (so for e4m3fn, $p \in [-18, 16]$)
- $c = 2^{p+k}$ with $k \geq 1$, which guarantees that the update to $c$ will either
  - flip a single mantissa bit on, or
  - do nothing (truncated to zero)
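The input construction can be sketched on the host. This is a minimal C++17 sketch assuming e4m3fn inputs (as on the 4090); the helpers are hypothetical and not taken from the author's test code. `e4m3fn_pow2` encodes $2^m$ as an e4m3fn byte, falling back to subnormals for $m < -6$, and `split_pow2` picks $m = \lfloor p/2 \rfloor$, $n = p - m$ so that both factors stay within e4m3fn's $[2^{-9}, 2^8]$ range of powers of two.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <utility>

// Hypothetical helper: encode 2^m as an e4m3fn byte (bias 7, 3 mantissa bits).
// Normals cover m in [-6, 8]; subnormals extend the range down to m = -9.
uint8_t e4m3fn_pow2(int m) {
    assert(m >= -9 && m <= 8);
    if (m >= -6) return uint8_t((m + 7) << 3);  // exponent field m + 7, mantissa 0
    return uint8_t(1 << (m + 9));              // subnormal: 2^-6 * (mantissa / 8)
}

// Decode a finite, non-NaN e4m3fn byte, to check the encoding above.
double e4m3fn_decode(uint8_t b) {
    const int exp = (b >> 3) & 0xF;
    const int man = b & 0x7;
    const double mag = (exp == 0) ? std::ldexp(man / 8.0, -6)
                                  : std::ldexp(1.0 + man / 8.0, exp - 7);
    return (b & 0x80) ? -mag : mag;
}

// Split a product exponent p into (m, n) with m + n = p and both exponents
// representable in e4m3fn, which requires p in [-18, 16].
std::pair<int, int> split_pow2(int p) {
    assert(p >= -18 && p <= 16);
    const int m = static_cast<int>(std::floor(p / 2.0));
    return {m, p - m};
}
```

Splitting $p$ evenly keeps both factors as far from the format's edges as possible, so the same helper covers the whole product range without special cases.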
When I execute that test on a 4090, I obtain unsurprising results:
```
$ ./fp8_4090
C[0:4,0:4]
131088 131072 131072 131072
131072 131072 131072 131072
131072 131072 131072 131072
131072 131072 131072 131072
p= -7 expected=7.812500000e-03 actual=0.000000000e+00 (lost)
p= -6 expected=1.562500000e-02 actual=0.000000000e+00 (lost)
p= -5 expected=3.125000000e-02 actual=0.000000000e+00 (lost)
p= -4 expected=6.250000000e-02 actual=0.000000000e+00 (lost)
p= -3 expected=1.250000000e-01 actual=0.000000000e+00 (lost)
p= -2 expected=2.500000000e-01 actual=0.000000000e+00 (lost)
p= -1 expected=5.000000000e-01 actual=0.000000000e+00 (lost)
p= 0 expected=1.000000000e+00 actual=0.000000000e+00 (lost)
p= 1 expected=2.000000000e+00 actual=0.000000000e+00 (lost)
p= 2 expected=4.000000000e+00 actual=0.000000000e+00 (lost)
p= 3 expected=8.000000000e+00 actual=0.000000000e+00 (lost)
p= 4 expected=1.600000000e+01 actual=1.600000000e+01 (kept)
p= 5 expected=3.200000000e+01 actual=3.200000000e+01 (kept)
p= 6 expected=6.400000000e+01 actual=6.400000000e+01 (kept)
p= 7 expected=1.280000000e+02 actual=1.280000000e+02 (kept)
p= 8 expected=2.560000000e+02 actual=2.560000000e+02 (kept)
p= 9 expected=5.120000000e+02 actual=5.120000000e+02 (kept)
p= 10 expected=1.024000000e+03 actual=1.024000000e+03 (kept)
p= 11 expected=2.048000000e+03 actual=2.048000000e+03 (kept)
p= 12 expected=4.096000000e+03 actual=4.096000000e+03 (kept)
p= 13 expected=8.192000000e+03 actual=8.192000000e+03 (kept)
p= 14 expected=1.638400000e+04 actual=1.638400000e+04 (kept)
p= 15 expected=3.276800000e+04 actual=3.276800000e+04 (kept)
p= 16 expected=6.553600000e+04 actual=6.553600000e+04 (kept)
main:~
$ ./fp8_4090 | grep kept | wc
13 65 832
```
In this run $c = 2^{17} = 131072$, so $k = 17 - p$, and the added bit is only retained iff $k < 14$ (i.e. $p \geq 4$). That indicates the result of $a \cdot b + c$ is truncated to FP22.
Code for the test can be found here
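The shape of this sweep can be reproduced with a host-side model that adds $2^p$ to $c = 2^{c_{exp}}$ and then truncates the sum to a fixed number of explicit mantissa bits. This is a minimal C++17 sketch under the assumption of round-toward-zero; `truncate_to` and `count_kept` are hypothetical names. Note that the experiment on its own cannot separate truncation from round-to-nearest-even: the $k = 14$ case is an exact tie, which also rounds back to $c$ under ties-to-even.

```cpp
#include <cassert>
#include <cmath>
#include <cstdio>

// Hypothetical accumulator model: round a positive, normal x toward zero so
// that only `kept` explicit mantissa bits remain (13 = FP22, 23 = FP32).
double truncate_to(double x, int kept) {
    assert(x > 0 && kept >= 0 && kept <= 23);
    int e;
    const double frac = std::frexp(x, &e);                // x = frac * 2^e, frac in [0.5, 1)
    const double scaled = std::ldexp(frac, kept + 1);     // leading bit + `kept` bits above the point
    return std::ldexp(std::trunc(scaled), e - kept - 1);  // drop everything below them
}

// Replay the sweep: c = 2^c_exp, add 2^p for p = c_exp - 24 .. c_exp - 1
// (i.e. k = c_exp - p from 24 down to 1), and count the updates that survive.
int count_kept(int c_exp, int kept, bool verbose) {
    const double c = std::ldexp(1.0, c_exp);
    int n_kept = 0;
    for (int p = c_exp - 24; p < c_exp; ++p) {
        const double expected = std::ldexp(1.0, p);
        const double actual = truncate_to(c + expected, kept) - c;  // exact in double
        const bool ok = (actual == expected);
        n_kept += ok ? 1 : 0;
        if (verbose)
            std::printf("p=%3d expected=%.9e actual=%.9e (%s)\n",
                        p, expected, actual, ok ? "kept" : "lost");
    }
    return n_kept;
}
```

`count_kept(17, 13, true)` prints a sweep over $p = -7 \dots 16$ with 13 `kept` lines, matching the 4090 count; `count_kept(15, 23, true)` models a full FP32 accumulator with $c = 2^{15}$, where only $k = 24$ is lost.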
### Code results
Recently, it was alleged by AMD devrel that mi300x accumulators for fp8 matmuls have superior precision compared to the H100.
As unfortunate as it may be for nvidia shareholders, the statement is true:
```
root@b908ffc12b49:/# ./fp8_mi300x | tee a
C[0:4,0:4]
32784 32768 32768 32768
32768 32768 32768 32768
32768 32768 32768 32768
32768 32768 32768 32768
p= -9 expected=1.953125000e-03 actual=0.000000000e+00 (lost)
p= -8 expected=3.906250000e-03 actual=3.906250000e-03 (kept)
p= -7 expected=7.812500000e-03 actual=7.812500000e-03 (kept)
p= -6 expected=1.562500000e-02 actual=1.562500000e-02 (kept)
p= -5 expected=3.125000000e-02 actual=3.125000000e-02 (kept)
p= -4 expected=6.250000000e-02 actual=6.250000000e-02 (kept)
p= -3 expected=1.250000000e-01 actual=1.250000000e-01 (kept)
p= -2 expected=2.500000000e-01 actual=2.500000000e-01 (kept)
p= -1 expected=5.000000000e-01 actual=5.000000000e-01 (kept)
p= 0 expected=1.000000000e+00 actual=1.000000000e+00 (kept)
p= 1 expected=2.000000000e+00 actual=2.000000000e+00 (kept)
p= 2 expected=4.000000000e+00 actual=4.000000000e+00 (kept)
p= 3 expected=8.000000000e+00 actual=8.000000000e+00 (kept)
p= 4 expected=1.600000000e+01 actual=1.600000000e+01 (kept)
p= 5 expected=3.200000000e+01 actual=3.200000000e+01 (kept)
p= 6 expected=6.400000000e+01 actual=6.400000000e+01 (kept)
p= 7 expected=1.280000000e+02 actual=1.280000000e+02 (kept)
p= 8 expected=2.560000000e+02 actual=2.560000000e+02 (kept)
p= 9 expected=5.120000000e+02 actual=5.120000000e+02 (kept)
p= 10 expected=1.024000000e+03 actual=1.024000000e+03 (kept)
p= 11 expected=2.048000000e+03 actual=2.048000000e+03 (kept)
p= 12 expected=4.096000000e+03 actual=4.096000000e+03 (kept)
p= 13 expected=8.192000000e+03 actual=8.192000000e+03 (kept)
p= 14 expected=1.638400000e+04 actual=1.638400000e+04 (kept)
root@b908ffc12b49:/# grep kept a | wc
23 115 1472
```
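The fp32 accumulator on an mi300x (executing e4m3_fnuz `mfma` instructions with rocWMMA) retains all 23 mantissa bits. Here $c = 2^{15} = 32768$, so $k = 15 - p$, and every update with $k \leq 23$ survives; the single loss at $k = 24$ is below FP32’s own precision. This was tested on the `rocm/pytorch-nightly` image via Runpod.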
### But is that really conclusive?
Not really. I’m not intelligent enough to fully understand PMC7959640. My tests only indicate a clear advantage for AMD’s accumulators on a toy problem.
For example, my tests obviously cannot determine if any positions in the inputs are “favored”. Consider:
- Are all values in a single 8x4x32 fp8 problem treated equally, for each cycle in an H100’s Tensor Core?
- Is the numerical behavior of a tensor core identical in each cycle of a full `QMMA`?
- If a single `mma.sync` compiles down to several `QMMA`s, is it possible that one has better or worse accuracy?
Note: the above are mere hypotheticals. Based on my limited testing, I have seen no reason to expect favored positions to exist.