
Rethinking the Relationship Between Learning Rate and Batch Size (Part 1) - Current Status


This is a gemini-2.5-flash translation of a Chinese article.

It has NOT been vetted for errors. You should have the original article open in a parallel tab at all times.

By Su Jianlin | 2025-09-01 | 1140 Readers

In previous articles, 《How should the learning rate change as Batch Size increases?》 and 《How Adam’s epsilon affects the Learning Rate Scaling Law?》, we discussed theoretically how the learning rate should scale with Batch Size. A classic tool there is the second-order expansion analysis proposed by OpenAI. However, for non-SGD optimizers, the calculations in this analysis often become so complex that they feel intractable.

In the following articles, I will reorganize and rethink the relevant details from the aforementioned articles. I will attempt to simplify some of the derivation steps, provide a more general and lightweight derivation path, and explore the possibility of extending it to the Muon optimizer.

Main Idea of the Method

First, let’s review the previous analysis method. In 《How should the learning rate change as Batch Size increases?》, we introduced various approaches to analyzing the relationship between learning rate and Batch Size. The second-order approximation analysis proposed by OpenAI in 《An Empirical Model of Large-Batch Training》 featured prominently there, and this article adopts the same approach.

Next, some notation. Let the loss function be $\mathcal{L}(\boldsymbol{w})$, where $\boldsymbol{w}\in\mathbb{R}^N$ is the parameter vector, and let $\boldsymbol{g}$ be its gradient. Ideally, the loss is an expectation over all training samples, but in practice we can only sample a batch, which makes the gradient random. We denote the gradient of a single sample by $\tilde{\boldsymbol{g}}$, with mean $\boldsymbol{g}$ and covariance matrix $\boldsymbol{\Sigma}$. With Batch Size $B$, the batch gradient (the average over $B$ independently drawn samples) is denoted $\tilde{\boldsymbol{g}}_B$; its mean is still $\boldsymbol{g}$, but its covariance matrix becomes $\boldsymbol{\Sigma}/B$.
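A quick numerical sanity check of the $\boldsymbol{\Sigma}/B$ claim for a single gradient component. This is a minimal sketch under assumptions of our own: the per-sample noise is Uniform(-1, 1) (variance 1/3), deliberately non-Gaussian to show that the scaling only needs independent samples, and `batch_mean_variance` is a hypothetical helper name.

```python
import random


def batch_mean_variance(batch_size, n_batches=20_000, seed=0):
    """Empirical variance of one gradient component averaged over a batch.

    Per-sample values are Uniform(-1, 1), variance 1/3 (an illustrative,
    deliberately non-Gaussian choice); theory predicts 1 / (3 * batch_size).
    """
    rng = random.Random(seed)
    means = [
        sum(rng.uniform(-1.0, 1.0) for _ in range(batch_size)) / batch_size
        for _ in range(n_batches)
    ]
    mu = sum(means) / n_batches
    # Unbiased sample variance of the batch means.
    return sum((m - mu) ** 2 for m in means) / (n_batches - 1)


for b in (1, 4, 16):
    # Columns: batch size, measured variance, predicted variance 1/(3B).
    print(b, round(batch_mean_variance(b), 5), round(1 / (3 * b), 5))
```

The measured and predicted columns should agree to within Monte Carlo error (about 1% here).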

Furthermore, let the current learning rate be $\eta$ and the update vector be $\tilde{\boldsymbol{\varphi}}_B$. The loss after the update is then approximately:

$$ \begin{aligned} \mathcal{L}(\boldsymbol{w} - \eta\tilde{\boldsymbol{\varphi}}_B) \approx&\, \mathcal{L}(\boldsymbol{w}) - \eta \tilde{\boldsymbol{\varphi}}_B^{\top}\boldsymbol{g} + \frac{1}{2}\eta^2\tilde{\boldsymbol{\varphi}}_B^{\top}\boldsymbol{H}\tilde{\boldsymbol{\varphi}}_B \\ =&\, \mathcal{L}(\boldsymbol{w}) - \eta \tilde{\boldsymbol{\varphi}}_B^{\top}\boldsymbol{g} + \frac{1}{2}\eta^2\operatorname{tr}(\tilde{\boldsymbol{\varphi}}_B\tilde{\boldsymbol{\varphi}}_B^{\top}\boldsymbol{H}) \end{aligned} $$

The first line is a second-order Taylor expansion, where $\boldsymbol{H}$ is the Hessian matrix; the second line uses the identity $\operatorname{tr}(\boldsymbol{A}\boldsymbol{B})=\operatorname{tr}(\boldsymbol{B}\boldsymbol{A})$, where $\operatorname{tr}$ denotes the matrix trace, to rewrite the quadratic form as a trace. To obtain a deterministic quantity, we take the expectation of both sides (with $\boldsymbol{g}$ and $\boldsymbol{H}$ fixed at the current $\boldsymbol{w}$):

$$ \mathbb{E}[\mathcal{L}(\boldsymbol{w} - \eta\tilde{\boldsymbol{\varphi}}_B)] \approx \mathcal{L}(\boldsymbol{w}) - \eta\mathbb{E}[\tilde{\boldsymbol{\varphi}}_B]^{\top}\boldsymbol{g} + \frac{1}{2}\eta^2 \operatorname{tr}(\mathbb{E}[\tilde{\boldsymbol{\varphi}}_B\tilde{\boldsymbol{\varphi}}_B^{\top}]\boldsymbol{H}) $$

We view the right-hand side as a quadratic function of $\eta$ and assume its quadratic coefficient is positive (guaranteed, for instance, by the stronger assumption that $\boldsymbol{H}$ is positive definite). Its minimizer is then:

$$ \eta^* \approx \frac{\mathbb{E}[\tilde{\boldsymbol{\varphi}}_B]^{\top}\boldsymbol{g}}{\operatorname{tr}(\mathbb{E}[\tilde{\boldsymbol{\varphi}}_B\tilde{\boldsymbol{\varphi}}_B^{\top}]\boldsymbol{H})} $$

This is the learning rate that, on average, decreases the loss the most in a single step, i.e., the theoretically optimal learning rate. Our task is therefore to compute $\mathbb{E}[\tilde{\boldsymbol{\varphi}}_B]$ and $\mathbb{E}[\tilde{\boldsymbol{\varphi}}_B\tilde{\boldsymbol{\varphi}}_B^{\top}]$ for a specific update rule $\tilde{\boldsymbol{\varphi}}_B$, and then read off from the formula above how $\eta^*$ depends on the Batch Size $B$.
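To make the formula concrete, here is a minimal Monte Carlo sketch that plugs an arbitrary update rule into $\eta^*$. The setup is a toy assumption, not from the article: a quadratic loss with diagonal Hessian (so the second-order expansion is exact and the trace only needs the diagonal second moments $\mathbb{E}[\tilde{\varphi}_i^2]$) and Gaussian gradient noise with diagonal covariance. All names and numbers are illustrative.

```python
import random

# Illustrative assumptions (not from the article): a toy quadratic loss
# L(w) = 0.5 * sum_i h_i * w_i**2 with diagonal Hessian H_DIAG, so the
# second-order expansion is exact, and Gaussian per-sample gradient noise
# with diagonal covariance diag(SIGMA**2).
H_DIAG = [1.0, 10.0]
SIGMA = [2.0, 2.0]


def loss(w):
    return 0.5 * sum(h * x * x for h, x in zip(H_DIAG, w))


def grad(w):
    return [h * x for h, x in zip(H_DIAG, w)]


def batch_grad(g, batch_size, rng):
    # Mean g, covariance diag(SIGMA**2) / batch_size.
    return [gi + rng.gauss(0.0, s) / batch_size ** 0.5 for gi, s in zip(g, SIGMA)]


def sgd_update(g_b):
    return list(g_b)


def sign_update(g_b):
    return [1.0 if x > 0 else -1.0 for x in g_b]


def optimal_lr(updates, g):
    """eta* = E[phi]^T g / tr(E[phi phi^T] H), using sample moments.

    With a diagonal H, tr(E[phi phi^T] H) = sum_i h_i * E[phi_i**2].
    """
    n, dim = len(updates), len(g)
    mean_phi = [sum(u[i] for u in updates) / n for i in range(dim)]
    second = [sum(u[i] ** 2 for u in updates) / n for i in range(dim)]
    numer = sum(m * gi for m, gi in zip(mean_phi, g))
    denom = sum(s * h for s, h in zip(second, H_DIAG))
    return numer / denom


def mean_loss_after_step(w, updates, eta):
    # Monte Carlo estimate of E[L(w - eta * phi)].
    return sum(
        loss([wi - eta * ui for wi, ui in zip(w, u)]) for u in updates
    ) / len(updates)


def estimate_optimal_lr(update, w, batch_size, n_samples=20_000, seed=0):
    rng = random.Random(seed)
    g = grad(w)
    updates = [update(batch_grad(g, batch_size, rng)) for _ in range(n_samples)]
    return optimal_lr(updates, g), updates


w0 = [1.0, 0.5]
for name, rule in (("SGD", sgd_update), ("SignSGD", sign_update)):
    eta, ups = estimate_optimal_lr(rule, w0, batch_size=16)
    losses = [mean_loss_after_step(w0, ups, k * eta) for k in (0.8, 1.0, 1.2)]
    print(name, round(eta, 4), [round(x, 5) for x in losses])
```

Because the toy loss is exactly quadratic, the sample-averaged loss after one step is an exact quadratic in $\eta$, so the $\eta^*$ computed from the same sample moments minimizes it; the losses printed at $0.8\eta^*$ and $1.2\eta^*$ should both be higher. Only the update function changes between SGD and SignSGD.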

Warm-up Exercise

As a first example, we naturally start with the simplest optimizer, SGD, for which $\tilde{\boldsymbol{\varphi}}_B=\tilde{\boldsymbol{g}}_B$. Then $\mathbb{E}[\tilde{\boldsymbol{\varphi}}_B]=\boldsymbol{g}$, and since a second moment equals the covariance plus the outer product of the mean, $\mathbb{E}[\tilde{\boldsymbol{\varphi}}_B\tilde{\boldsymbol{\varphi}}_B^{\top}]=\boldsymbol{g}\boldsymbol{g}^{\top} + \boldsymbol{\Sigma}/B$. Thus we have:

$$ \eta^* \approx \frac{\boldsymbol{g}^{\top}\boldsymbol{g}}{\operatorname{tr}((\boldsymbol{g}\boldsymbol{g}^{\top} + \boldsymbol{\Sigma}/B)\boldsymbol{H})} = \frac{\boldsymbol{g}^{\top}\boldsymbol{g}}{\boldsymbol{g}^{\top}\boldsymbol{H}\boldsymbol{g} + \operatorname{tr}(\boldsymbol{\Sigma}\boldsymbol{H})/B} = \frac{\eta_{\max}}{1 + \mathcal{B}_{\text{noise}}/B} $$

where

$$ \eta_{\max} = \frac{\boldsymbol{g}^{\top}\boldsymbol{g}}{\boldsymbol{g}^{\top}\boldsymbol{H}\boldsymbol{g}},\qquad\mathcal{B}_{\text{noise}} = \frac{\operatorname{tr}(\boldsymbol{\Sigma}\boldsymbol{H})}{\boldsymbol{g}^{\top}\boldsymbol{H}\boldsymbol{g}} $$

This result admits several interpretations. First, $\eta^*$ is a monotonically increasing function of $B$, bounded above by $\eta_{\max}$: the learning rate cannot grow indefinitely with the Batch Size, which matches intuition better than simple linear or square-root scaling laws. When $B \ll \mathcal{B}_{\text{noise}}$, we have:

$$ \eta^* \approx \frac{\eta_{\max}}{1 + \mathcal{B}_{\text{noise}}/B} \approx \frac{\eta_{\max}}{\mathcal{B}_{\text{noise}}/B} = \eta_{\max} B / \mathcal{B}_{\text{noise}} $$
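A small sketch of the SGD formulas: it computes $\eta_{\max}$ and $\mathcal{B}_{\text{noise}}$ from $\boldsymbol{g}$, $\boldsymbol{H}$, $\boldsymbol{\Sigma}$ and compares $\eta^*(B)$ with the small-batch linear approximation. The matrices are hypothetical numbers chosen only to make both regimes visible, and plain Python lists stand in for a linear algebra library.

```python
def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def matvec(m, v):
    return [dot(row, v) for row in m]


def trace_of_product(a, b):
    # tr(A B) = sum_{i,j} A[i][j] * B[j][i]
    n = len(a)
    return sum(a[i][j] * b[j][i] for i in range(n) for j in range(n))


def sgd_lr_stats(g, h, sigma):
    """Return (eta_max, B_noise) for SGD given g, Hessian h, per-sample covariance sigma."""
    g_h_g = dot(g, matvec(h, g))
    if g_h_g <= 0:
        raise ValueError("need g^T H g > 0 (positive curvature along g)")
    return dot(g, g) / g_h_g, trace_of_product(sigma, h) / g_h_g


def eta_star(batch_size, eta_max, b_noise):
    return eta_max / (1.0 + b_noise / batch_size)


# Hypothetical numbers, chosen only to make the two regimes visible.
g = [1.0, 0.5]
H = [[2.0, 0.0], [0.0, 1.0]]
SIGMA = [[4.0, 1.0], [1.0, 4.0]]
eta_max, b_noise = sgd_lr_stats(g, H, SIGMA)
for b in (1, 4, 16, 64, 256, 1024):
    # Columns: B, exact eta*(B), linear approximation eta_max * B / B_noise.
    print(b, round(eta_star(b, eta_max, b_noise), 4), round(eta_max * b / b_noise, 4))
```

The second column saturates at $\eta_{\max}$ while the third keeps growing linearly, so the two agree only when $B \ll \mathcal{B}_{\text{noise}}$.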

Hence, when the Batch Size is small, the optimal SGD learning rate is indeed approximately linear in the Batch Size, and $\mathcal{B}_{\text{noise}}$ emerges as a crucial statistic. However, $\mathcal{B}_{\text{noise}}$ depends on the Hessian $\boldsymbol{H}$, which is almost impossible to compute precisely for LLMs. In practice, one therefore usually assumes $\boldsymbol{H}$ is the identity matrix (or a multiple of it), which gives the simplified form:

$$ \mathcal{B}_{\text{simple}} = \frac{\operatorname{tr}(\boldsymbol{\Sigma})}{\boldsymbol{g}^{\top}\boldsymbol{g}} $$

This has the form of noise intensity ($\operatorname{tr}(\boldsymbol{\Sigma})$) divided by signal intensity ($\boldsymbol{g}^{\top}\boldsymbol{g}$), i.e., the inverse of a signal-to-noise ratio: the lower the signal-to-noise ratio, the larger the Batch Size needed to get close to $\eta_{\max}$, which also matches intuition. Moreover, $\operatorname{tr}(\boldsymbol{\Sigma})$ depends only on the diagonal of $\boldsymbol{\Sigma}$, so we only need to estimate the mean and variance of each gradient component independently, which is feasible in practice.
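A minimal sketch of estimating $\mathcal{B}_{\text{simple}}$ from per-sample gradients, assuming they are drawn independently. Only per-component means and variances are used. One detail is our addition: the squared norm of the sample mean overestimates $\boldsymbol{g}^{\top}\boldsymbol{g}$ by $\operatorname{tr}(\boldsymbol{\Sigma})/n$ on average (the same $\boldsymbol{\Sigma}/B$ effect described earlier), so the sketch subtracts that bias. The function name is hypothetical.

```python
def b_simple(per_sample_grads):
    """Estimate B_simple = tr(Sigma) / (g^T g) from per-sample gradients.

    per_sample_grads: list of n >= 2 gradient vectors (lists of floats).
    """
    n = len(per_sample_grads)
    if n < 2:
        raise ValueError("need at least two per-sample gradients")
    dim = len(per_sample_grads[0])
    mean = [sum(gr[i] for gr in per_sample_grads) / n for i in range(dim)]
    # Only the diagonal of Sigma is needed: one unbiased variance per component.
    var = [
        sum((gr[i] - mean[i]) ** 2 for gr in per_sample_grads) / (n - 1)
        for i in range(dim)
    ]
    tr_sigma = sum(var)
    # E[|mean|^2] = g^T g + tr(Sigma) / n, so subtract tr(Sigma) / n to debias.
    signal = sum(m * m for m in mean) - tr_sigma / n
    if signal <= 0:
        raise ValueError("noise dominates; signal estimate is not positive")
    return tr_sigma / signal


# tr(Sigma) = 4, debiased g^T g = 5 - 4/2 = 3
print(b_simple([[1.0, 0.0], [3.0, 2.0]]))  # prints 1.3333333333333333
```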

Data Efficiency

Beyond the direct relationship between learning rate and Batch Size, I find the asymptotic relationship it implies between the total amount of training data and the number of training steps equally exciting and essential to understand. In particular, this conclusion appears to be more general than the SGD learning rate law: as we will see later, SignSGD yields a conclusion of the same form, even though its learning rate scaling law differs from SGD’s.

The original paper’s treatment of this part is relatively involved; the derivation below is my simplified version. Substituting $\eta^*$ back into the expected loss $\mathbb{E}[\mathcal{L}(\boldsymbol{w} - \eta\tilde{\boldsymbol{g}}_B)]$, we get:

$$ \overline{\Delta\mathcal{L}} = \mathcal{L}(\boldsymbol{w}) - \mathbb{E}[\mathcal{L}(\boldsymbol{w} - \eta^*\tilde{\boldsymbol{g}}_B)] \approx \frac{\Delta\mathcal{L}_{\max}}{1 + \mathcal{B}_{\text{noise}}/B} $$

where $\Delta\mathcal{L}_{\max} = \frac{(\boldsymbol{g}^{\top}\boldsymbol{g})^2}{2\boldsymbol{g}^{\top}\boldsymbol{H}\boldsymbol{g}}$. How should we interpret this result? First, it is a monotonically increasing function of $B$ that tends to $\Delta\mathcal{L}_{\max}$ as $B\to\infty$. In other words, with an infinitely large Batch Size the loss would decrease by $\Delta\mathcal{L}_{\max}$ per step, so the number of training steps required would be minimal; we denote this minimum by $S_{\min}$.
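For reference, the substitution step in full. Write the expected one-step decrease as $\eta a - \frac{1}{2}\eta^2 b$ with $a = \boldsymbol{g}^{\top}\boldsymbol{g}$ and $b = \boldsymbol{g}^{\top}\boldsymbol{H}\boldsymbol{g} + \operatorname{tr}(\boldsymbol{\Sigma}\boldsymbol{H})/B$; at $\eta^* = a/b$ it equals $a^2/(2b)$, and dividing numerator and denominator by $\boldsymbol{g}^{\top}\boldsymbol{H}\boldsymbol{g}$ gives the result:

$$ \overline{\Delta\mathcal{L}} \approx \eta^*\boldsymbol{g}^{\top}\boldsymbol{g} - \frac{1}{2}\eta^{*2}\left(\boldsymbol{g}^{\top}\boldsymbol{H}\boldsymbol{g} + \frac{\operatorname{tr}(\boldsymbol{\Sigma}\boldsymbol{H})}{B}\right) = \frac{(\boldsymbol{g}^{\top}\boldsymbol{g})^2}{2\left(\boldsymbol{g}^{\top}\boldsymbol{H}\boldsymbol{g} + \operatorname{tr}(\boldsymbol{\Sigma}\boldsymbol{H})/B\right)} = \frac{\Delta\mathcal{L}_{\max}}{1 + \mathcal{B}_{\text{noise}}/B} $$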

With a finite Batch Size, the average loss reduction per step is only $\overline{\Delta\mathcal{L}}$. This means that, on average, we need $1 + \mathcal{B}_{\text{noise}}/B$ steps to achieve the loss reduction of a single step at infinite Batch Size. Therefore, reaching the same loss takes $S = (1 + \mathcal{B}_{\text{noise}}/B)S_{\min}$ steps.

Since each step consumes $B$ samples, the total amount of training data is $E = BS = (B + \mathcal{B}_{\text{noise}})S_{\min}$. So increasing the Batch Size requires correspondingly more data $E$ to reach the same loss. As $B\to 0$, the data requirement approaches its minimum $E_{\min} = \mathcal{B}_{\text{noise}}S_{\min}$. With this notation, we can write:

$$ \left(\frac{S}{S_{\min}} - 1\right)\left(\frac{E}{E_{\min}} - 1\right) = 1 $$

This is the classic relationship between training data volume and training steps, with two parameters, $S_{\min}$ and $E_{\min}$. Conversely, we can experimentally measure several $(S,E)$ pairs (the steps and data needed to reach a fixed loss at different Batch Sizes), fit $S_{\min}$ and $E_{\min}$ to the equation above, and then estimate $\mathcal{B}_{\text{noise}} = E_{\min} / S_{\min}$. For more analytical details, please refer back to the previous article 《How should the learning rate change as Batch Size increases?》 or OpenAI’s original paper 《An Empirical Model of Large-Batch Training》.
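A minimal sketch of this fitting procedure, under stated assumptions: the $(S, E)$ pairs are synthetic and noise-free, generated from hypothetical values $S_{\min}=1000$ and $\mathcal{B}_{\text{noise}}=64$, and the helper names are ours. Expanding the hyperbola gives $S_{\min}/S + E_{\min}/E = 1$, which is linear in $(1/S, 1/E)$, so an ordinary least-squares fit suffices. Note also that at $B = \mathcal{B}_{\text{noise}}$ the formulas give exactly $S = 2S_{\min}$ and $E = 2E_{\min}$.

```python
def steps_and_data(batch_size, s_min, b_noise):
    """Steps S and data E needed to reach a fixed loss at Batch Size B (SGD model)."""
    steps = (1.0 + b_noise / batch_size) * s_min
    return steps, batch_size * steps  # E = B * S = (B + B_noise) * S_min


def fit_s_min_e_min(pairs):
    """Least-squares fit of S_min, E_min from measured (S, E) pairs.

    (S/S_min - 1)(E/E_min - 1) = 1 is equivalent to S_min/S + E_min/E = 1,
    which is linear in (1/S, 1/E); we solve the 2x2 normal equations.
    """
    xs = [1.0 / s for s, _ in pairs]
    ys = [1.0 / e for _, e in pairs]
    sxx = sum(x * x for x in xs)
    syy = sum(y * y for y in ys)
    sxy = sum(x * y for x, y in zip(xs, ys))
    sx, sy = sum(xs), sum(ys)
    det = sxx * syy - sxy * sxy
    if det == 0:
        raise ValueError("need at least two distinct (S, E) pairs")
    s_min = (sx * syy - sy * sxy) / det
    e_min = (sy * sxx - sx * sxy) / det
    return s_min, e_min, e_min / s_min  # last value estimates B_noise


# Hypothetical ground truth, used only to generate noise-free (S, E) pairs.
S_MIN, B_NOISE = 1000.0, 64.0
pairs = [steps_and_data(b, S_MIN, B_NOISE) for b in (8, 32, 128, 512)]
print(fit_s_min_e_min(pairs))  # should recover roughly (1000, 64000, 64)
```

The linearization is one convenient choice rather than a prescription from the paper; with noisy measurements, fitting in $(1/S, 1/E)$ space weights the points differently from a direct nonlinear fit.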

Difficulty Analysis

Everything above, however, stays within the realm of SGD, which is computationally trivial; the real complexity arises when $\tilde{\boldsymbol{\varphi}}_B$ depends nonlinearly on $\tilde{\boldsymbol{g}}_B$. For instance, SignSGD corresponds to $\tilde{\boldsymbol{\varphi}}_B=\operatorname{sign}(\tilde{\boldsymbol{g}}_B)$ and is often used as a stand-in for Adam in theoretical analysis. A more accurate approximation is SoftSignSGD, which accounts for Adam’s $\epsilon$; we attempted to analyze it in 《How Adam’s epsilon affects the Learning Rate Scaling Law?》.

In these nonlinear cases, computing $\mathbb{E}[\tilde{\boldsymbol{\varphi}}_B]$ and $\mathbb{E}[\tilde{\boldsymbol{\varphi}}_B\tilde{\boldsymbol{\varphi}}_B^{\top}]$ is often quite difficult, even if we assume $\tilde{\boldsymbol{g}}_B$ follows a simple normal distribution (note that the SGD analysis required no assumption about the form of the distribution). For example, in a previous article, computing $\mathbb{E}[\tilde{\boldsymbol{\varphi}}_B]$ for SignSGD, where $\tilde{\boldsymbol{\varphi}}_B=\operatorname{sign}(\tilde{\boldsymbol{g}}_B)$, took the following steps:

  1. Assume that the components of $\tilde{\boldsymbol{g}}_B$ are independent, reducing the problem to the expectation of a single component $\tilde{\varphi}_B=\operatorname{sign}(\tilde{g}_B)$ (a scalar, hence not bold);
  2. Assume that $\tilde{g}_B$ follows a normal distribution; then $\mathbb{E}[\tilde{\varphi}_B]$ can be computed in closed form, expressed via the $\operatorname{erf}$ function;
  3. Approximate the $\operatorname{erf}$ function by a function of the form $x/\sqrt{x^2+c}$ to simplify the result.
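These SignSGD steps can be sketched numerically for one component. Assuming $\tilde{g}_B \sim \mathcal{N}(\mu, \sigma^2)$, where $\sigma$ is the per-sample standard deviation divided by $\sqrt{B}$, we have $\mathbb{E}[\operatorname{sign}(\tilde{g}_B)] = \operatorname{erf}(\mu/(\sigma\sqrt{2}))$. The article does not fix $c$ here; we pick $c=\pi/4$ (our assumption), which matches the slope of $\operatorname{erf}$ at 0 and turns the approximation into $\mu/\sqrt{\mu^2 + \pi\sigma^2/2}$. With this choice the approximation error stays below roughly 0.11, so it is a rough fit rather than a tight one. Function names are hypothetical.

```python
import math
import random


def sign_mean_exact(mu, sigma):
    # If g ~ N(mu, sigma^2): E[sign(g)] = P(g > 0) - P(g < 0) = erf(mu / (sigma * sqrt(2))).
    return math.erf(mu / (sigma * math.sqrt(2.0)))


def erf_approx(x, c=math.pi / 4):
    # x / sqrt(x^2 + c); c = pi/4 matches erf's slope 2/sqrt(pi) at x = 0 (assumed choice).
    return x / math.sqrt(x * x + c)


def sign_mean_approx(mu, sigma):
    # With c = pi/4 this equals mu / sqrt(mu^2 + pi * sigma^2 / 2).
    return erf_approx(mu / (sigma * math.sqrt(2.0)))


def sign_mean_monte_carlo(mu, sigma, n=100_000, seed=0):
    rng = random.Random(seed)
    return sum(1.0 if rng.gauss(mu, sigma) > 0 else -1.0 for _ in range(n)) / n


# sigma is the batch-level std: per-sample std (here 1.0) divided by sqrt(B).
for batch_size in (1, 16, 256):
    sigma = 1.0 / math.sqrt(batch_size)
    print(
        batch_size,
        round(sign_mean_monte_carlo(0.1, sigma), 4),
        round(sign_mean_exact(0.1, sigma), 4),
        round(sign_mean_approx(0.1, sigma), 4),
    )
```

For small $\mu/\sigma$ (small $B$) the mean is roughly proportional to $\mu/\sigma$, hence to $\sqrt{B}$, and it saturates at 1 as $B$ grows.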

In other words, only after a series of convoluted steps could we obtain an approximate result amenable to further analysis (this procedure first appeared in Tencent’s paper 《Surge Phenomenon in Optimal Learning Rate and Batch Size Scaling》). And this is the simple case; SoftSignSGD is even more involved:

  1. Assume that the components of $\tilde{\boldsymbol{g}}_B$ are independent, reducing the problem to the expectation of a single component $\tilde{\varphi}_B=\operatorname{softsign}(\tilde{g}_B, \epsilon)$;
  2. Approximate the $\operatorname{softsign}$ function by a piecewise linear function so that the integral in the next step can be computed;
  3. Assume that $\tilde{g}_B$ follows a normal distribution; combined with the approximation from step 2, $\mathbb{E}[\tilde{\varphi}_B]$ can then be computed, yielding a complicated function involving $\operatorname{erf}$;
  4. Approximate that complicated function by a function of the form $x/\sqrt{x^2+c}$ to simplify the result.
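When a closed form is this costly, Monte Carlo sampling is a pragmatic cross-check for a single component. The sketch assumes $\operatorname{softsign}(x, \epsilon) = x/(|x|+\epsilon)$, mirroring Adam’s $m/(\sqrt{v}+\epsilon)$ with $m\approx g$ and $v\approx g^2$; this form is our assumption, so check it against the definition in the earlier article. It estimates the first and second moments for a normal $\tilde{g}_B$ and shows the two limits: $\epsilon\to 0$ recovers SignSGD’s $\operatorname{erf}$ result, while large $\epsilon$ gives roughly $\mu/\epsilon$, i.e., SGD with a rescaled learning rate.

```python
import math
import random


def softsign(x, eps):
    # Assumed form, mirroring Adam's m / (sqrt(v) + eps) with m ~ g, v ~ g^2.
    return x / (abs(x) + eps)


def softsign_moments_mc(mu, sigma, eps, n=100_000, seed=0):
    """Monte Carlo E[softsign(g, eps)] and E[softsign(g, eps)^2] for g ~ N(mu, sigma^2)."""
    rng = random.Random(seed)
    vals = [softsign(rng.gauss(mu, sigma), eps) for _ in range(n)]
    return sum(vals) / n, sum(v * v for v in vals) / n


mu, sigma = 0.3, 1.0
for eps in (1e-6, 0.1, 1.0, 100.0):
    m1, m2 = softsign_moments_mc(mu, sigma, eps)
    print(eps, round(m1, 5), round(m2, 5))
print("SignSGD limit:", round(math.erf(mu / (sigma * math.sqrt(2))), 5))
print("large-eps limit mu/eps at eps=100:", mu / 100.0)
```

The trade-off is that sampling gives numbers, not a formula in $B$, so it validates approximations rather than replacing the analysis.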

The story isn’t over yet. After all that effort and all those assumptions, we have only just managed to compute $\mathbb{E}[\tilde{\boldsymbol{\varphi}}_B]$. We still need $\mathbb{E}[\tilde{\boldsymbol{\varphi}}_B\tilde{\boldsymbol{\varphi}}_B^{\top}]$, which is often even more complex (SignSGD is an exception, since $\operatorname{sign}(x)^2 = 1$ for any $x \neq 0$, making it simpler). But the computational burden is secondary; the main issue is that these steps seem to follow no generalizable pattern and amount to case-by-case analyses, which is exhausting.
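Under the same component-independence assumption as step 1, $\mathbb{E}[\tilde{\boldsymbol{\varphi}}_B\tilde{\boldsymbol{\varphi}}_B^{\top}]$ is determined by per-component moments: off-diagonal entries are $\mathbb{E}[\tilde{\varphi}_i]\mathbb{E}[\tilde{\varphi}_j]$ and diagonal entries are $\mathbb{E}[\tilde{\varphi}_i^2]$. The sketch below (hypothetical helper names and numbers) assembles this matrix and evaluates the trace in the denominator of $\eta^*$. For SignSGD the diagonal is identically 1, which is exactly why it is the easy exception; for other rules, $\mathbb{E}[\tilde{\varphi}_i^2]$ requires a separate calculation.

```python
def second_moment_matrix(means, squares):
    """E[phi phi^T] when the components of phi are independent.

    Off-diagonal: E[phi_i] * E[phi_j]; diagonal: E[phi_i^2].
    """
    n = len(means)
    return [
        [squares[i] if i == j else means[i] * means[j] for j in range(n)]
        for i in range(n)
    ]


def trace_with_hessian(moment, h):
    # tr(E[phi phi^T] H) = sum_{i,j} M[i][j] * H[j][i]
    n = len(h)
    return sum(moment[i][j] * h[j][i] for i in range(n) for j in range(n))


# SignSGD: sign(x)^2 = 1, so the diagonal is all ones and only E[phi] is needed.
sign_means = [0.8, -0.2, 0.05]  # hypothetical per-component E[sign(g_i)]
m = second_moment_matrix(sign_means, [1.0, 1.0, 1.0])
H = [[2.0, 0.1, 0.0], [0.1, 1.0, 0.0], [0.0, 0.0, 0.5]]
print(m)
print(trace_with_hessian(m, H))
```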

Summary

To avoid making the article too long, this post will end here, primarily providing a brief review of existing analytical results and computational difficulties. In the next article, I will introduce some attempts I’ve made to reduce the cognitive load during the derivation process.

@online{kexuefm-11260,
        title={Rethinking the Relationship Between Learning Rate and Batch Size (Part 1) - Current Status},
        author={苏剑林},
        year={2025},
        month={09},
        url={\url{https://kexue.fm/archives/11260}},
}