By Su Jianlin | 2025-11-06
Let $z_1,z_2,\cdots,z_n$ be $n$ random variables independently and identically sampled from a standard normal distribution. From these, we can construct many derived random variables. For example, $z_1+z_2+\cdots+z_n$, which still follows a normal distribution, or $z_1^2+z_2^2+\cdots+z_n^2$, which follows a chi-squared distribution. In this article, we focus on the distribution of their maximum, $z_{\max} = \max\{z_1,z_2,\cdots,z_n\}$, and especially on its expectation $\mathbb{E}[z_{\max}]$.
First Look at the Conclusion#
The basic estimation result for $\mathbb{E}[z_{\max}]$ is: let $z_1,z_2,\cdots,z_n\sim\mathcal{N}(0,1)$ and $z_{\max} = \max\{z_1,z_2,\cdots,z_n\}$; then
$$ \begin{equation}\mathbb{E}[z_{\max}]\sim \sqrt{2\log n}\end{equation} $$
Here, the meaning of $\sim$ in Equation (1) is:
$$ \begin{equation}\lim_{n\to\infty} \frac{\mathbb{E}[z_{\max}]}{\sqrt{2\log n}} = 1\end{equation} $$Thus, it can be seen that this result becomes relatively more accurate as $n$ increases. A more precise result is
$$ \begin{equation}\mathbb{E}[z_{\max}]\sim \sqrt{2\log \frac{n}{\sqrt{2\pi}}}\end{equation} $$We can verify both approximations using NumPy:
import numpy as np
n = 4096
z = np.random.randn(10000, n)
E_z_max = z.max(axis=1).mean() # ≈ 3.63
approx1 = np.sqrt(2 * np.log(n)) # ≈ 4.08
approx2 = np.sqrt(2 * np.log(n / np.sqrt(2 * np.pi))) # ≈ 3.85
Quick Upper Bound#
For the above conclusion, this article will provide three proofs. The first proof comes from @Sivaraman’s answer in the post 《Expectation of the maximum of gaussian random variables》. It essentially only proves $\mathbb{E}[z_{\max}] \leq \sqrt{2\log n}$, but the proof itself is quite elegant and worth studying.
The proof cleverly utilizes the convexity of $\exp$: for any $t > 0$, we can write
$$ \begin{equation}\exp(t\mathbb{E}[z_{\max}]) \leq \mathbb{E}[\exp(t z_{\max})] = \mathbb{E}[\max_i \exp(t z_i)]\leq \sum_{i=1}^n\mathbb{E}[\exp(t z_i)] = n \exp(t^2 / 2)\end{equation} $$The first $\leq$ follows from Jensen’s inequality, the second $\leq$ replaces the maximum with a sum, and the final equality uses the moment generating function of the standard normal, $\mathbb{E}[\exp(t z_i)] = \exp(t^2/2)$. Now, taking the logarithm on both sides and rearranging, we get
$$ \begin{equation}\mathbb{E}[z_{\max}] \leq \frac{\log n}{t} + \frac{t}{2}\end{equation} $$Note that this holds for any $t > 0$, so we can choose the $t$ that minimizes the right-hand side to obtain the tightest bound. By the AM–GM inequality, $\frac{\log n}{t} + \frac{t}{2} \geq 2\sqrt{\frac{\log n}{t}\cdot\frac{t}{2}} = \sqrt{2\log n}$, with equality at $t=\sqrt{2\log n}$. Substituting this choice of $t$, we get
$$ \begin{equation}\mathbb{E}[z_{\max}] \leq \sqrt{2\log n}\end{equation} $$The appeal of this derivation is its simplicity and speed, requiring little prior knowledge. Theoretically, however, it only provides an upper bound, although that bound is surprisingly accurate, coinciding with the asymptotic result.
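As a quick sanity check of both claims (an upper bound, yet asymptotically tight in the sense of Equation (2)), here is a minimal sketch using only NumPy:
import numpy as np

for n in (2**4, 2**8, 2**12):
    # Monte Carlo estimate of E[z_max] versus the bound sqrt(2 * log(n));
    # the ratio stays below 1 but creeps towards 1 as n grows
    z_max = np.random.randn(10000, n).max(axis=1)
    bound = np.sqrt(2 * np.log(n))
    print(n, z_max.mean(), bound, z_max.mean() / bound)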
Conventional Approach#
For those accustomed to deriving formulas systematically (like the author), the above derivation might feel somewhat “unconventional”: the standard route is to first find the probability density function of $z_{\max}$ and then compute the expectation by integration. In this section, we proceed along that route.
Probability Density#
For a 1-dimensional distribution, the probability density function (PDF) and cumulative distribution function (CDF) are “two sides of the same coin” that can be mutually derived. Finding the probability density of $z_{\max}$ requires the aid of the cumulative distribution function. The probability density of a standard normal distribution is $p(z)=\exp(-z^2/2)/\sqrt{2\pi}$. The cumulative distribution function $\Phi(z)=\int_{-\infty}^z p(x)dx$ is a non-elementary function, whose meaning is the probability that the random variable is less than or equal to $z$.
To find the cumulative distribution function of $z_{\max}$, i.e., $P(z_{\max} \leq z)$, it is easy to see that $z_{\max} \leq z$ is equivalent to $z_1 \leq z, z_2\leq z, \cdots, z_n\leq z$ holding simultaneously. Since $z_i$ are independently sampled, the probability of them all holding simultaneously is the product of their individual probabilities:
$$ \begin{equation}P(z_{\max} \leq z) = \prod_{i=1}^n P(z_i \leq z) = [\Phi(z)]^n \end{equation} $$Thus, the cumulative distribution function of $z_{\max}$ is $[\Phi(z)]^n$, a very concise result. Note that we have not yet used the condition that $z$ follows a normal distribution, so this is in fact a general result: for $n$ numbers independently and identically sampled from any distribution, the cumulative distribution function of their maximum value is the $n$-th power of the original distribution’s cumulative distribution function. Now, differentiating it gives the probability density function $p_{\max}(z)$ of $z_{\max}$:
$$ \begin{equation}p_{\max}(z) = n[\Phi(z)]^{n-1} p(z) \end{equation} $$The plots of $p_{\max}(z)$ for $n=50,100,200$ are shown below:
[Figure: plots of $p_{\max}(z)$ for $n=50,100,200$]
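For reference, these curves can be reproduced directly from Equation (8); a minimal plotting sketch, assuming SciPy and Matplotlib are available:
import numpy as np
from scipy.stats import norm
import matplotlib.pyplot as plt

z = np.linspace(0, 5, 1000)
for n in (50, 100, 200):
    # Equation (8): p_max(z) = n * Phi(z)^(n-1) * p(z)
    plt.plot(z, n * norm.cdf(z)**(n - 1) * norm.pdf(z), label=f'n={n}')
plt.xlabel('z')
plt.ylabel('p_max(z)')
plt.legend()
plt.show()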
Laplace Approximation#
With the probability density function, we can theoretically calculate the expectation by integration:
$$ \begin{equation}\mathbb{E}[z_{\max}] = \int_{-\infty}^{\infty} z\, p_{\max}(z) dz\end{equation} $$However, it is evident that this integral is difficult to compute, so we need to find an approximation that is easier to calculate. From the plots above, it can be seen that the shape of $p_{\max}(z)$ also resembles a bell-shaped curve, similar to a normal distribution. Thus, it is natural to seek a normal distribution approximation, which is also called “Laplace Approximation”.
The first step in finding a normal distribution approximation is to find the maximum point of the bell-shaped curve, and then expand $\log p_{\max}(z)$ to the second order around that point. If the goal is only to find the mean, then finding the maximum point is sufficient, as the mean of a normal distribution is its mode (the maximum point of its probability density). To find this maximum point $z_*$, we first compute $\log p_{\max}(z)$:
$$ \begin{equation}\begin{aligned} \log p_{\max}(z) =&\,\log n + (n-1)\log \Phi(z) + \log p(z) \\ =&\,\log n + (n-1)\log \Phi(z) - \frac{z^2}{2} - \frac{1}{2}\log 2\pi \end{aligned}\end{equation} $$Differentiating with respect to $z$, we get
$$ \begin{equation}\frac{d}{dz}\log p_{\max}(z) = (n-1) \frac{p(z)}{\Phi(z)} - z = \frac{(n-1)\exp(-z^2/2)}{\Phi(z) \sqrt{2\pi}} - z\end{equation} $$Setting it to 0, we can rearrange to get
$$ \begin{equation}z_* = \sqrt{2\log\frac{n-1}{z_*\Phi(z_*)\sqrt{2\pi}}}\end{equation} $$
Approximate Solution#
The next task is to solve Equation (12). Of course, we do not need to find an exact solution, but only an asymptotic estimate. Note that $z\Phi(z)$ is already greater than 1 when $z\geq 1.15$. Since we are considering an asymptotic solution, it is natural to assume $z_* \geq 1.15$, so we have
$$ \begin{equation}z_* < \sqrt{2\log\frac{n-1}{\sqrt{2\pi}}}\end{equation} $$Substituting this result back into Equation (12) and using $\Phi(z) < 1$, we get
$$ \begin{equation}z_* > \sqrt{2\log\frac{n-1}{\sqrt{2\log\frac{n-1}{\sqrt{2\pi}}}\sqrt{2\pi}}}\end{equation} $$From these two bounds, we can obtain
$$ \begin{equation}z_* \sim \sqrt{2\log\frac{n-1}{\sqrt{2\pi}}}\sim \sqrt{2\log\frac{n}{\sqrt{2\pi}}}\sim \sqrt{2\log n}\end{equation} $$This is the asymptotic result for $\mathbb{E}[z_{\max}]$ that we are seeking. For further discussion on this problem, please refer to the Fisher–Tippett–Gnedenko theorem.
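As a cross-check, Equation (12) can also be solved numerically by fixed-point iteration. Below is a minimal sketch (assuming SciPy; the helper mode_of_max is ours), whose output can be compared with the asymptotic forms:
import numpy as np
from scipy.stats import norm

def mode_of_max(n, iters=100):
    # solve Equation (12) for z_* by fixed-point iteration, starting from sqrt(2 log n)
    z = np.sqrt(2 * np.log(n))
    for _ in range(iters):
        z = np.sqrt(2 * np.log((n - 1) / (z * norm.cdf(z) * np.sqrt(2 * np.pi))))
    return z

n = 4096
print(mode_of_max(n))                               # ≈ 3.5
print(np.sqrt(2 * np.log(n / np.sqrt(2 * np.pi))))  # ≈ 3.85
print(np.sqrt(2 * np.log(n)))                       # ≈ 4.08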
Inverse Transform Sampling#
The last proof is based on the inverse transform sampling approach: let the cumulative distribution function of a 1-dimensional distribution be $\Phi(z)$, and its inverse function be $\Phi^{-1}(z)$. Then, one way to sample from this distribution is
$$ \begin{equation}z = \Phi^{-1}(\varepsilon),\qquad \varepsilon\sim U(0,1)\end{equation} $$That is, the uniform distribution can be transformed into the target distribution through the inverse cumulative distribution function. So, if we sample $n$ points $z_1,z_2,\cdots,z_n$ from an arbitrary target distribution, it is equivalent to saying that $\Phi(z_1),\Phi(z_2),\cdots,\Phi(z_n)$ are $n$ points sampled from $U(0,1)$. If we want to find $\mathbb{E}[z_{\max}]$, we can approximately assume
$$ \begin{equation}\mathbb{E}[z_{\max}]\approx\Phi^{-1}(\mathbb{E}[\Phi(z_{\max})])\end{equation} $$That is, we first find the corresponding expectation in $U(0,1)$ and then transform it back using $\Phi^{-1}$. It is not hard to guess that for $n$ points sampled from $U(0,1)$, the average of their maximum values will be approximately $\frac{n}{n+1}$ [dividing the interval $(0,1)$ into $n+1$ equal parts, there are exactly $n$ points, taking the maximum one]. So we have $\mathbb{E}[z_{\max}]\approx \Phi^{-1}(\frac{n}{n+1})$. This is a general result. Next, we will combine it with a specific cumulative distribution function to obtain a more explicit solution.
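In fact, $\frac{n}{n+1}$ is exact for the uniform case, since $\mathbb{E}[\max_i \varepsilon_i]=\int_0^1 u\cdot n u^{n-1}\,du = \frac{n}{n+1}$. A quick numerical check of $\mathbb{E}[z_{\max}]\approx \Phi^{-1}(\frac{n}{n+1})$, sketched with SciPy:
import numpy as np
from scipy.stats import norm

n = 4096
z_max = np.random.randn(10000, n).max(axis=1)
print(z_max.mean())           # Monte Carlo E[z_max], ≈ 3.6
print(norm.ppf(n / (n + 1)))  # Phi^{-1}(n/(n+1)), ≈ 3.5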
For the standard normal distribution, we have $\Phi(z) = \frac{1}{2} + \frac{1}{2}\operatorname{erf}\left(\frac{z}{\sqrt{2}}\right) = 1 - \frac{1}{2}\operatorname{erfc}\left(\frac{z}{\sqrt{2}}\right)$. The $\operatorname{erfc}$ function has the asymptotic form $\operatorname{erfc}(z)\sim \frac{\exp(-z^2)}{z\sqrt{\pi}}$ (which can be derived by integration by parts). Substituting this into $\Phi(z)$ yields
$$ \begin{equation}\Phi(z)\sim 1 - \frac{\exp(-z^2/2)}{z\sqrt{2\pi}}\end{equation} $$Therefore, finding $\Phi^{-1}(\frac{n}{n+1})$ is approximately equivalent to solving the equation
$$ \begin{equation}\frac{\exp(-z^2/2)}{z\sqrt{2\pi}} = \frac{1}{n+1}\end{equation} $$This is largely similar to the equation in the previous section. Following the solution process from the previous section, we can obtain
$$ \begin{equation}\mathbb{E}[z_{\max}] \sim \sqrt{2\log\frac{n+1}{\sqrt{2\pi}}}\sim \sqrt{2\log\frac{n}{\sqrt{2\pi}}}\sim \sqrt{2\log n}\end{equation} $$
Application Example#
In the article 《Biased Rounding Errors Might Exist in Low-Precision Attention》, we introduced a mechanism that can lead to computational bias in low-precision Attention. One condition for this is the simultaneous existence of multiple maximum values in the same row of Attention Logits. Is this condition easily met? Using the results from this article, we can make an estimate of its probability of occurrence.
Assume a row has $n$ Logits, all independently and identically sampled from $\mathcal{N}(0,1)$. One could also allow a general mean and variance, but this would not change the conclusion below; and although the actual distribution may not be normal, this still serves as a reasonable baseline. The question now is: if these $n$ Logits are all converted to BF16 format, what is the probability that at least two identical maximum values appear?
According to the aforementioned results, the maximum value of these $n$ Logits is approximately $\nu = \sqrt{2\log\frac{n-1}{\sqrt{2\pi}}}$. BF16 only has a 7-bit mantissa, meaning a relative precision of $2^{-7}=1/128$. Therefore, if just one of the remaining $n-1$ Logits falls within the interval $(\frac{127}{128}\nu,\nu]$, it can be considered that two identical maximum values have appeared in BF16. According to the meaning of the probability density function, the probability of a single sample falling within this interval is $p(\nu)\frac{\nu}{128}$. Thus, the probability that at least one of the $n-1$ numbers falls within this interval is
$$ \begin{equation}1 - \left(1 - p(\nu)\frac{\nu}{128}\right)^{n-1} = 1 - \left(1 - \frac{\nu/128}{n-1}\right)^{n-1}\approx 1 - e^{-\nu/128} \approx \frac{\nu}{128}\end{equation} $$Here the first equality uses $p(\nu)=\exp(-\nu^2/2)/\sqrt{2\pi}=\frac{1}{n-1}$, which follows directly from the definition of $\nu$. Note that once we have singled out the maximum, the remaining $n-1$ numbers can no longer strictly be regarded as independent samples from a standard normal distribution; treating them as such will certainly underestimate the result for sufficiently large $n$, but the author believes it is usable as a simple approximation.
Comparison with numerical simulation results:
import jax
import jax.numpy as jnp
def proba_of_multi_max(n, T=100, seed=42):
    # Monte Carlo estimate: fraction of rows whose BF16 maximum appears at least twice
    p, key = 0, jax.random.key(seed)
    for i in range(T):
        key, subkey = jax.random.split(key)
        logits = jax.random.normal(subkey, (10000, n)).astype('bfloat16')
        p += ((logits == logits.max(axis=1, keepdims=True)).sum(axis=1) > 1).mean()
    return p / T

def approx_result(n):
    # the nu / 128 estimate derived above
    return jnp.sqrt(2 * jnp.log(n / jnp.sqrt(2 * jnp.pi))) / 128
proba_of_multi_max(128) # 0.018246
approx_result(128) # 0.0219115
proba_of_multi_max(4096) # 0.028279
approx_result(4096) # 0.03005291
proba_of_multi_max(65536) # 0.05296699
approx_result(65536) # 0.03523674
It can be seen that even with a sequence length of only 128, there is approximately a 2% probability of repeated maximum values, which is very significant. This is because Flash Attention is computed block-wise, and the typical block length is 128. With a 2% probability of repeated maximums for every 128 Logits, the probability of at least one occurrence of repeated maximum values during the entire Attention computation for a sample is on the order of $1 - 0.98^{n^2/128}$ (don’t forget that the size of the Logits matrix is the square of the sequence length). If we substitute $n=4096$, this probability already approaches 100%.
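The arithmetic behind “approaches 100%” can be made explicit (a rough sketch, assuming a block length of 128):
n = 4096                     # sequence length
num_blocks = n * n // 128    # number of 128-Logit blocks in the n x n Logits matrix
p_none = 0.98 ** num_blocks  # chance that no block contains a repeated maximum
print(1 - p_none)            # ≈ 1.0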
There’s a small detail here: in actual Attention computation, the Logits matrix is typically not in BF16 format. Instead, it is converted to BF16 after subtracting the maximum and then exponentiating. In this case, the maximum value becomes 1, and the problem is equivalent to calculating the probability that at least two 1s appear in each row of that matrix. However, this detailed version of the result will not have a significant difference compared to directly converting the Logits matrix to BF16.
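For completeness, here is a minimal sketch of that variant (same setup as the JAX snippet above; the name proba_of_multi_one is ours), counting rows in which at least two entries of exp(logits - max) round to 1 in BF16:
import jax
import jax.numpy as jnp

def proba_of_multi_one(n, T=100, seed=42):
    p, key = 0, jax.random.key(seed)
    for _ in range(T):
        key, subkey = jax.random.split(key)
        logits = jax.random.normal(subkey, (10000, n))
        # subtract the row maximum, exponentiate, then cast to BF16,
        # mimicking the softmax step of low-precision Attention
        probs = jnp.exp(logits - logits.max(axis=1, keepdims=True)).astype('bfloat16')
        p += ((probs == 1).sum(axis=1) > 1).mean()
    return p / T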
Summary#
This article estimated the mathematical expectation of the maximum of $n$ normal random variables using three different methods, and based on the results, provided a simple estimate for the probability of repeated maximum values occurring in low-precision Attention matrices.
@online{kexuefm-11390,
title={Asymptotic Estimate for the Maximum of n Normal Random Variables},
author={苏剑林},
year={2025},
month={11},
url={\url{https://kexue.fm/archives/11390}},
}