
Low-Precision Attention May Suffer from Biased Rounding Errors


This is a gemini-2.5-flash translation of a Chinese article.

It has NOT been vetted for errors. You should have the original article open in a parallel tab at all times.

By Su Jianlin | 2025-10-27

Recently, I came across the paper “Why Low-Precision Transformer Training Fails: An Analysis on Flash Attention” on arXiv. The experimental phenomena it describes align very well with issues we encountered when training Kimi K2, such as problems starting from the second Attention layer. The paper attributes these issues to intrinsic biased errors in low-precision Attention, an analysis perspective that quite surprised me, so I read it with great interest.

However, the paper’s description seemed somewhat difficult to grasp—partly because I am not very familiar with low-precision computations. In any case, after consulting the authors multiple times, I finally managed to understand the paper, and so I am documenting my understanding here for everyone’s reference.

Brief Summary of Conclusions

It’s worth noting that while the paper’s title explicitly mentions “Flash Attention,” according to its description, the same problems still arise even if block_size is set to the entire training sequence length. Thus, Flash Attention’s block-wise computation is not the root cause of the issue. Therefore, we can simplify the analysis by considering a naive low-precision Attention implementation.

For simplicity, we analyze single-head Attention. Let $\boldsymbol{Q},\boldsymbol{K},\boldsymbol{V}\in\mathbb{R}^{n\times d}$. Denote $\boldsymbol{S} = \boldsymbol{Q}\boldsymbol{K}^{\top}$. The bold $\boldsymbol{1}$ refers to an $n\times 1$ matrix of all ones, and $\boldsymbol{S}_{\max}$ refers to an $n\times 1$ matrix obtained by taking the maximum value of each row of $\boldsymbol{S}$. Then

$$ \boldsymbol{O} = \frac{\exp(\boldsymbol{S})\boldsymbol{V}}{\exp(\boldsymbol{S})\boldsymbol{1}} = \frac{\exp(\boldsymbol{S} - \boldsymbol{S}_{\max})\boldsymbol{V}}{\exp(\boldsymbol{S}- \boldsymbol{S}_{\max})\boldsymbol{1}} $$

We define $\bar{\boldsymbol{P}} = \exp(\boldsymbol{S} - \boldsymbol{S}_{\max})$. The critical computation in Attention is the matrix multiplication $\bar{\boldsymbol{P}}\boldsymbol{V}$, which is typically performed in BF16 precision. The paper’s conclusion is that in low-precision computation, the $\bar{\boldsymbol{P}}\boldsymbol{V}$ step exhibits a biased rounding error: averaged over the long run, the expected difference between the low-precision result of $\bar{\boldsymbol{P}}\boldsymbol{V}$ and the exact value is not zero.
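To make the setup concrete, here is a minimal sketch (assuming PyTorch with BF16 matmul available on your device; the function name `naive_attention_bf16` is mine, not the paper’s) of this naive Attention with the $\bar{\boldsymbol{P}}\boldsymbol{V}$ step done in BF16, compared against a high-precision reference. With random inputs the measured bias is usually negligible; the point is only to show where the low-precision step sits.

```python
import torch

def naive_attention_bf16(Q, K, V):
    # Follows the formula above: only P_bar and V are cast to BF16 for the matmul.
    S = Q.float() @ K.float().T
    P_bar = torch.exp(S - S.max(dim=-1, keepdim=True).values)
    num = P_bar.bfloat16() @ V.bfloat16()     # the critical low-precision step
    den = P_bar.sum(dim=-1, keepdim=True)
    return num.float() / den

torch.manual_seed(0)
Q, K, V = (torch.randn(1024, 64) for _ in range(3))

O_lp = naive_attention_bf16(Q, K, V)
O_ref = torch.softmax((Q @ K.T).double(), dim=-1) @ V.double()

# Signed mean error; the paper's claim is that under the conditions discussed
# below (focused attention plus anisotropic V), this expectation is not zero.
print((O_lp.double() - O_ref).mean().item())
```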

Consequently, these deviations can keep accumulating across training steps, leading to issues such as MaxLogit explosion, Loss Spikes, and ultimately training collapse. Of course, strictly speaking, this is just one possible mechanism behind problems like MaxLogit explosion, not necessarily the only one. But even so, it is worth studying and keeping in mind.

Round-to-Even

To understand the paper’s conclusions, let’s first review some basic knowledge about rounding errors. The reason for writing this section, as stated at the beginning, is that I am not familiar with low-precision computations. Therefore, this section is entirely for my own foundational review; readers who are already familiar with this topic can skip it.

We know that the common rounding method is “round half up”: in base 10, when rounding a positive number with one decimal place to an integer, a final digit of 0-4 is dropped (rounding down), producing errors of $0,-0.1,-0.2,-0.3,-0.4$; a final digit of 5-9 triggers rounding up, producing errors of $+0.5,+0.4,+0.3,+0.2,+0.1$. You may have noticed that the average of these errors is not 0 but 0.05. In other words, “round half up” tends to enlarge the original number on average, introducing a positive bias.

Naturally, the relative bias shrinks as more digits are discarded; for example, rounding a two-decimal-place number to an integer gives an average error of 0.005. But the positive bias of round half up is always there, only in varying magnitudes. The root of the problem lies at the halfway points. For example, 0.51 (rounded up) and 0.49 (rounded down) have errors that cancel each other out; but for 0.50, whichever way it is rounded, there is no counterpart to cancel its error.

To eliminate this bias, IEEE 754 adopted the “Round-to-Even” rule: halfway cases are rounded towards the nearest even digit. For instance, 2.5 rounded to an integer becomes 2, while 3.5 becomes 4. This way, a trailing “5” has an equal chance of producing a $+0.5$ or $-0.5$ error, so the average error is zero and the bias is eliminated.
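As a quick check in plain Python (using the standard `decimal` module), the snippet below reproduces the 0.05 average error of “round half up” and the zero average error of “round half to even” over all one-decimal numbers in [2.0, 4.0), a range chosen so that the ties 2.5 and 3.5 both appear.

```python
from decimal import Decimal, ROUND_HALF_UP, ROUND_HALF_EVEN

# All one-decimal numbers in [2.0, 4.0): two decades, so both ties 2.5 and 3.5 occur.
vals = [Decimal(k) / 10 for k in range(20, 40)]

def mean_rounding_error(mode):
    errors = [v.quantize(Decimal(1), rounding=mode) - v for v in vals]
    return sum(errors) / len(errors)

print(mean_rounding_error(ROUND_HALF_UP))    # 0.05 -> positive bias
print(mean_rounding_error(ROUND_HALF_EVEN))  # 0    -> ties split both ways, no bias

# Python's built-in round() already follows round-to-even:
print(round(2.5), round(3.5))                # 2 4
```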

Returning to the computer domain. We know computers use binary, with only 0 and 1. In this context, “1” plays the role that “5” does in decimal. The bias of “round half up” is more intuitive in binary, as the last bit can only be 0 or 1: if it’s 0, no change is needed; if it’s 1, it triggers “rounding up” and carries over 1. Thus, when a binary number is rounded by “round half up” to discard its last bit, the result will invariably be greater than or equal to the original number. Therefore, “round-to-even” is also needed to eliminate this bias.
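The same point in a few lines of Python: with “round half up”, dropping the last binary bit either changes nothing (bit 0) or rounds up (bit 1), so the error is never negative.

```python
# Two-bit patterns x.y, where y is the bit to be dropped: "round half up" keeps x
# when y = 0 and adds 1 when y = 1. Errors, in units of the dropped bit, are never negative.
vals = [0b00, 0b01, 0b10, 0b11]
rounded = [(v + 1) >> 1 for v in vals]
print([2 * r - v for r, v in zip(rounded, vals)])   # [0, 1, 0, 1]
```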

BF16 Addition

Next, let’s review the BF16 format. BF16 uses 16 binary bits to represent a floating-point number, with 1 bit for the sign, 7 for the mantissa, and 8 for the exponent. The 8-bit exponent design gives it the same dynamic range as FP32 (1 sign bit, 23 mantissa bits, 8 exponent bits), making it the primary floating-point format for LLM training today.
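A small check of the layout (assuming PyTorch): BF16 is simply the top 16 bits of the corresponding FP32 bit pattern, i.e. the same sign bit and 8-bit exponent, with the 23-bit mantissa cut down to 7 bits (plus rounding of the discarded part).

```python
import struct
import torch

x = 1.0 + 2**-7    # 1.0000001 in binary, exactly representable in BF16
bits32 = struct.unpack(">I", struct.pack(">f", x))[0]
bits16 = torch.tensor([x], dtype=torch.bfloat16).view(torch.uint16).item()

print(f"{bits32:032b}")   # 00111111100000010000000000000000  (FP32)
print(f"{bits16:016b}")   # 0011111110000001 -> exactly the top 16 bits
```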

Compared with FP16, BF16 keeps more exponent bits at the cost of fewer mantissa bits, which means lower precision. To mitigate the accumulated errors caused by low precision, BF16 operations adopt a strategy of “BF16 multiplication, FP32 addition”: BF16 numbers are accumulated by first converting them to FP32, performing the additions in FP32 to obtain an FP32 result, and then converting back to BF16.

Now let’s consider adding two BF16 numbers with the same sign and exponent. Why choose numbers with the same exponent for analysis? Because we want to estimate error, and identical exponents mean the numbers are of the same magnitude, making it most likely to produce the largest error after addition. For example, if two numbers being added differ by a factor of 100, then even if I simply return the larger one, the error would be no more than 1%. So, the maximum error often occurs when adding numbers of the same order of magnitude.

Adding two BF16 numbers with the same sign and exponent inevitably produces a carry. For instance, in binary, $1.0000001 + 1.0000100 = 10.0000101 = 1.00000101 \times 2$. The exponent must be incremented by 1, and the last bit “1” must be discarded to convert the result back to BF16 format. As discussed in the previous section, discarding that last bit with “round half up” would introduce a positive bias; but scientists discovered this bias long ago and proposed “round-to-even” precisely to eliminate it.

Two Large, One Small

So far, everything is under control and as expected, with no bias in sight. But then the unexpected happens.

Now consider adding three numbers with the same sign, where two of them share the same large exponent and the third is very small. Building on the previous example “1.0000001 + 1.0000100”, if we also add “0.0000000001”, we get, in binary, $1.0000001 + 1.0000100 + 0.0000000001 = 10.0000101001 = 1.00000101001 \times 2$.

Originally, adding the two numbers gave $1.00000101 \times 2$, and discarding the last bit triggers “round-to-even”, yielding $1.0000010 \times 2$. But now, with the extra minuscule number, the mantissa bits to be discarded when converting back to BF16 become “1001”, which is greater than the halfway point. This triggers a round-up instead, yielding $1.0000011 \times 2$. From the perspective of the original two-number addition, the appearance of the third, minuscule number disrupts the “round-to-even” rule, and the positive bias reappears!
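The following snippet (assuming PyTorch, whose FP32-to-BF16 conversion uses round-to-nearest-even) replays this exact example numerically: the two-number sum lands on a tie and rounds down, while the tiny third number breaks the tie and forces a round-up.

```python
import torch

a = torch.tensor(0b10000001 / 128, dtype=torch.bfloat16)   # 1.0000001 (binary) = 1 + 2**-7
b = torch.tensor(0b10000100 / 128, dtype=torch.bfloat16)   # 1.0000100 (binary) = 1 + 2**-5
tiny = 2.0 ** -10                                          # 0.0000000001 (binary)

s = a.float() + b.float()            # FP32 sum: 10.0000101 (binary) = 2.0390625
print(s.bfloat16().item())           # 2.03125  -> the tie is rounded down (round-to-even)
print((s + tiny).bfloat16().item())  # 2.046875 -> the tiny term breaks the tie: rounds up
```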

Of course, the conditions for this situation appear quite stringent. Firstly, all three numbers must have the same sign. Secondly, they must satisfy the “two large, one small” condition, where the two large numbers are just enough to trigger a carry, and the small number is so tiny that it only affects the mantissa of FP32 (i.e., the 9th to 23rd mantissa bits). In this way, the small number itself, when rounded, introduces little error. However, its presence just happens to disrupt the “round-to-even” rule for the two large numbers, thereby introducing a one-sided bias.

Tailor-Made

Can such stringent conditions actually occur in practice? It’s indeed not easy under general circumstances, but for Attention, this seems like a “tailor-made” bug!

Let’s take an element from a specific row and column of $\bar{\boldsymbol{P}}\boldsymbol{V}$, which can be written as

$$ \sum_{i=1}^n \bar{p}_i v_i $$

where $\bar{p}_i = \exp(s_i - \max_j s_j)\leq 1$ and the $s_i$ are the entries of the corresponding row of $\boldsymbol{S}$. We know that the characteristic of Softmax Attention is its ability to “focus attention”: attention may concentrate on a limited number of tokens, which shows up in the $\bar{p}_i$ as a few values close to 1 while the rest are very close to 0. However, because of the exponential function, they cannot be exactly zero (unless they underflow BF16’s representable range).

Furthermore, as layers stack and training progresses, the input $\boldsymbol{V}$ may exhibit “anisotropy.” One manifestation of this is an uneven distribution of signs (positive/negative) in certain dimensions. Without loss of generality, let’s assume most $v_i$ values are positive (the same logic applies to negative numbers) and are roughly of the same magnitude. Then, the sum $\sum_{i=1}^n \bar{p}_i v_i$ can be divided into two parts: a few $\bar{p}_i$ values that are close to 1, multiplied by their respective $v_i$, forming the dominant terms of the sum; and the remaining terms, which are the products of most $\bar{p}_i$ values (close to 0) with their $v_i$ values, forming the negligible remainder.

Thus, with “the right time and place”, the bug from the previous section is triggered perfectly: most terms are positive, and the sum of the dominant terms satisfies the carry condition; the remaining terms are minuscule, affecting only the least significant mantissa bits of FP32, which is exactly what disrupts the “round-to-even” rule and introduces bias. Finally, because attention is “focused”, the number of dominant terms is small, so the carry does not happen many times over (the more bits that get discarded, the smaller the relative bias), which keeps the bias at a non-negligible level!

Doesn’t this combination make it a “tailor-made bug” for Attention?
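Here is a toy simulation of this mechanism (assuming PyTorch; this is my own construction, not the paper’s setup): a few dominant weights, a tail of tiny positive weights, positive $v_i$ of the same magnitude, products in BF16, accumulation in FP32, and one final downcast to BF16. The signed mean error comes out clearly positive, whereas an unbiased rounding scheme would keep it near zero.

```python
import torch

torch.manual_seed(0)
n_rows, n = 100_000, 64

# "Focused attention": two dominant weights equal to 1, the rest tiny but positive
# (the exponential tail is small yet never exactly zero).
p = torch.full((n_rows, n), 1e-5)
p[:, :2] = 1.0

# An "anisotropic" slice of V: all positive and of the same order of magnitude.
v = (1.0 + torch.rand(n_rows, n)).bfloat16()

# BF16 multiply, FP32 accumulate, one final downcast to BF16.
prod = (p.bfloat16() * v).float()
out = prod.sum(dim=1).bfloat16()

ref = (p.double() * v.double()).sum(dim=1)     # high-precision reference
err = out.double() - ref

# Mean signed error is around +3e-3 on sums of about 3: a clear upward bias,
# driven by ties that are always broken upward by the tiny positive tail.
print(err.mean().item(), ref.mean().item())
```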

Eliminating the Remainder

After understanding the root cause of the problem, let’s consider how to solve it.

Superficially, the bias is caused by minuscule remainder terms disrupting “round-to-even”. On deeper reflection, though, the fundamental reason is that the current deterministic rounding rule (round to nearest, ties to even), while unbiased in theory, is too fragile: all sorts of perturbations can easily turn it biased. The most radical solution would be Stochastic Rounding, which rounds up or down probabilistically and thus largely avoids biases caused by small perturbations.

However, Stochastic Rounding does not have efficient hardware-level implementations, so most hardware’s matrix multiplication operators currently do not include Stochastic Rounding. Therefore, the original paper chose to confront the problem directly, with an approach I call “eliminating the remainder.” Specifically, when a certain trigger condition is detected, we modify the Attention computation formula to

$$ \boldsymbol{O} = \frac{\exp(\boldsymbol{S})\boldsymbol{V}}{\exp(\boldsymbol{S})\boldsymbol{1}} = \frac{\exp(\boldsymbol{S} - \beta\boldsymbol{S}_{\max})\boldsymbol{V}}{\exp(\boldsymbol{S}- \beta\boldsymbol{S}_{\max})\boldsymbol{1}} $$

where $\beta > 1$. In this way, each term needs to be additionally divided by $\exp((\beta-1)\boldsymbol{S}_{\max})$, which is not a small number (the paper sets $\beta \geq 2$). Consequently, the originally minuscule remainder terms are prone to underflow to zero and vanish. Then, “round-to-even” resumes its function, thereby eliminating the bias.

So, what are the detection conditions? The original paper’s approach is relatively simple: the modification is triggered when a row of matrix $\boldsymbol{S}$ contains two or more maximum values, meaning at least two $\bar{p}_i$ terms are 1. However, I believe there’s considerable room for adjustment here, leaving it as a direction for improvement. It’s also worth noting that Flash Attention is computed block-wise, so this detection condition and modification are also applied block-by-block. Details can be found in the code in the original paper’s appendix.
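For concreteness, here is a minimal sketch of the naive (non-Flash) version of this fix, written from the description above rather than taken from the paper’s appendix. It assumes the problematic rows have $\boldsymbol{S}_{\max} > 0$ (the MaxLogit-explosion regime), so that the extra division by $\exp((\beta-1)\boldsymbol{S}_{\max})$ really does push the tail terms toward underflow.

```python
import torch

def attention_with_beta_rescaling(Q, K, V, beta=2.0):
    S = Q.float() @ K.float().T
    S_max = S.max(dim=-1, keepdim=True).values

    # Trigger: a row of S contains two or more maxima, i.e. at least two p̄_i equal 1.
    trigger = (S == S_max).sum(dim=-1, keepdim=True) >= 2
    scale = 1.0 + (beta - 1.0) * trigger.float()   # 1 normally, beta when triggered

    # Subtracting scale * S_max leaves the exact result unchanged, but in BF16 the
    # tiny tail weights now underflow to zero and can no longer break rounding ties.
    P_bar = torch.exp(S - scale * S_max).bfloat16()
    num = P_bar @ V.bfloat16()                     # the formerly problematic BF16 matmul
    den = P_bar.float().sum(dim=-1, keepdim=True)
    return num.float() / den
```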

Further Thoughts

Overall, the paper offers a relatively unique perspective for understanding phenomena like MaxLogit explosion. It can explain some aspects but doesn’t cover the full picture, and it leaves many points worthy of thought (and potential criticism).

Firstly, the paper’s analysis of the Attention bias relies on the anisotropy of $\boldsymbol{V}$. This might explain why issues like MaxLogit explosion do not appear in the first Attention layer but only from the second layer onward: the input to the first Attention layer is the Embedding, which is relatively less prone to anisotropy, whereas the input to the second and later Attention layers has already passed through an Attention layer and may inherently exhibit anisotropy (reference).

However, this cannot explain why MaxLogit explosion occurs only in specific layers. For example, the paper’s experiments show problems only in the second layer, while K2 showed problems in layers 2-4. Nor can it explain why Muon is more prone to MaxLogit explosion than Adam (also observed in K2). These problems are therefore likely the combined result of multiple factors, including architecture, optimizer, and low precision; focusing on precision alone gives an incomplete picture.

Furthermore, another question worth pondering is causality. Another condition for the Attention bias described in the paper is that attention is concentrated on a few tokens. At this point, intervening in the Attention computation successfully prevents subsequent anomalies. However, I observed a normally trained small model where attention was not as concentrated as imagined; for instance, the average probability of Top-1 was less than 0.2, and the cumulative probability for Top-400 only reached 0.9 (with a training length of 4096).

So, is Attention bias the “cause” or the “effect” of training collapse? In other words, when “attention is concentrated on a few tokens,” could it indicate that the model has already entered a state of collapse? If intervention only happens at that point, might it be “too late”? For example, while some anomalies might be prevented in terms of metrics, could it be that the model is no longer able to scale effectively? These questions remain unknown for now.

Summary (formatted)

This article shares an analysis paper on the bias in low-precision Attention computation, and at the same time, I took this opportunity to review the foundational concepts of low-precision computation.

@online{kexuefm-11371,
        title={Low-Precision Attention May Suffer from Biased Rounding Errors},
        author={苏剑林},
        year={2025},
        month={10},
        url={\url{https://kexue.fm/archives/11371}},
}