
How Should the Learning Rate Change When Batch Size Increases?


This is a gemini-2.5-flash translation of a Chinese article.

It has NOT been vetted for errors. You should have the original article open in a parallel tab at all times.

With the rapid growth of computational power, more and more scenarios hope to “trade compute for time,” i.e., to shorten model training time by stacking compute. Ideally, we hope that investing $n$ times the computational power would reduce the time needed to reach the same result to $1/n$ of the original, so that the total computational cost stays the same. This “hope” seems very reasonable and natural, but it is actually non-trivial. Even if we disregard bottlenecks like communication, when computational power exceeds a certain scale or the model is smaller than a certain scale, increasing computational power often only allows for an increase in Batch Size. However, can increasing Batch Size always shorten training time while maintaining the same performance?

This is the topic we will discuss next: When Batch Size increases, how should various hyperparameters, especially the learning rate, be adjusted to maintain the original training performance and maximize training efficiency? We can also call this the Scaling Law between Batch Size and learning rate.

Variance Perspective
#

Intuitively, when Batch Size increases, the gradient for each Batch becomes more accurate, so we can take larger steps, i.e., increase the learning rate, in order to reach the destination faster and shorten training time. This is generally understandable. The question is, how much of an increase is most appropriate?

Square Root Scaling
#

The earliest answer to this question might be square root scaling, meaning if the Batch Size is increased by $n$ times, the learning rate is increased by $\sqrt{n}$ times. This idea originated from the 2014 paper 《One weird trick for parallelizing convolutional neural networks》, and its derivation principle is to keep the variance of the SGD increment constant. Specifically, let the gradient of a randomly sampled data point be denoted as $\tilde{\boldsymbol{g}}$, with its mean and covariance denoted as $\boldsymbol{g}$ and $\boldsymbol{\Sigma}$, respectively, where $\boldsymbol{g}$ is the gradient for the entire dataset. When we increase the number of samples to $B$, we have:

$$ \tilde{\boldsymbol{g}}_B \triangleq \frac{1}{B}\sum_{i=1}^B \tilde{\boldsymbol{g}}^{(i)},\quad \mathbb{E}[\tilde{\boldsymbol{g}}_B] = \boldsymbol{g},\quad \mathbb{E}[(\tilde{\boldsymbol{g}}_B-\boldsymbol{g})(\tilde{\boldsymbol{g}}_B-\boldsymbol{g})^{\top}]=\frac{\boldsymbol{\Sigma}}{B} $$

This means that increasing the number of samples does not change the mean, while the covariance shrinks to $1/B$ of its single-sample value. For an SGD optimizer, the increment is $-\eta \tilde{\boldsymbol{g}}_B$, and its covariance is proportional to $\eta^2/B$. We believe that an appropriate amount of noise (not too much, not too little) is necessary during the optimization process. Therefore, when Batch Size $B$ changes, we adjust the learning rate $\eta$ to keep the noise intensity, i.e., the covariance matrix of the increment, constant. From this, we derive:

$$ \frac{\eta^2}{B} = \text{constant}\quad\Rightarrow\quad \eta\propto \sqrt{B} $$

This gives us the square root scaling law between learning rate and Batch Size. Later, 《Train longer, generalize better: closing the generalization gap in large batch training of neural networks》 also agreed with this choice.
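As a minimal sketch of the rule above, the snippet below rescales a learning rate so that $\eta^2/B$ stays fixed and checks that the per-component variance of the increment is unchanged. The function names, the base values, and the unit single-sample variance `sigma2` are illustrative assumptions, not taken from the cited papers.

```python
import math

def sqrt_scaled_lr(base_lr: float, base_batch: int, new_batch: int) -> float:
    """Rescale the learning rate so that eta^2 / B stays constant."""
    return base_lr * math.sqrt(new_batch / base_batch)

def increment_variance(lr: float, batch: int, sigma2: float = 1.0) -> float:
    """Per-component variance of the SGD increment -eta * g_B: eta^2 * sigma2 / B."""
    return lr ** 2 * sigma2 / batch

base_lr, base_batch = 0.1, 256  # hypothetical baseline settings
for new_batch in (256, 1024, 4096):
    lr = sqrt_scaled_lr(base_lr, base_batch, new_batch)
    # The last column stays the same for every batch size.
    print(new_batch, lr, increment_variance(lr, new_batch))
```

Under this rule, quadrupling the Batch Size only doubles the learning rate.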

Linear Scaling
#

Interestingly, linear scaling, i.e., $\eta\propto B$, often performs better in practice. Even the author of the aforementioned 《One weird trick for parallelizing convolutional neural networks》, who first proposed square root scaling, pointed this out in the paper and admitted to having no reasonable explanation for it.

To some extent, linear scaling aligns more with our intuitive understanding, especially as suggested by 《Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour》: if the gradient directions of $n$ consecutive Batches do not change much, then linear scaling is almost self-evident. However, this assumption is clearly too strong. Relaxing this assumption requires connecting SGD with SDE (Stochastic Differential Equations), which was achieved by 《Stochastic Modified Equations and Dynamics of Stochastic Gradient Algorithms I: Mathematical Foundations》. However, the first paper to point out the scaling relationship between learning rate and Batch Size using this connection should be 《On the Generalization Benefit of Noise in Stochastic Gradient Descent》.

In retrospect, the establishment of this connection is not difficult to understand. Let the model parameters be $\boldsymbol{\theta}$. Then the SGD update rule can be rewritten as:

$$ \boldsymbol{\theta}_{t+1} =\boldsymbol{\theta}_t - \eta \tilde{\boldsymbol{g}}_{B,t} =\boldsymbol{\theta}_t - \eta \boldsymbol{g}_t - \eta (\tilde{\boldsymbol{g}}_{B,t} - \boldsymbol{g}_t) $$

where $\tilde{\boldsymbol{g}}_{B,t} - \boldsymbol{g}_t$ is the gradient noise. So far, we have made no assumptions about the distribution of this noise, only that its mean is $\boldsymbol{0}$ and its covariance is $\boldsymbol{\Sigma}_t/B$. Next, we assume that this noise follows a normal distribution $\mathcal{N}(\boldsymbol{0},\boldsymbol{\Sigma}_t/B)$. Then the above iteration can be further rewritten as:

$$ \begin{aligned}\boldsymbol{\theta}_{t+1} =&\, \boldsymbol{\theta}_t - \eta \boldsymbol{g}_t - \eta (\tilde{\boldsymbol{g}}_{B,t} - \boldsymbol{g}_t) \\[5pt] =&\, \boldsymbol{\theta}_t - \eta \boldsymbol{g}_t - \eta \sqrt{\frac{\boldsymbol{\Sigma}_t}{B}}\boldsymbol{z},\quad \boldsymbol{z}\sim \mathcal{N}(\boldsymbol{0},\boldsymbol{I}) \\[5pt] =&\, \boldsymbol{\theta}_t - \eta \boldsymbol{g}_t - \sqrt{\eta} \sqrt{\frac{\eta\boldsymbol{\Sigma}_t}{B}}\boldsymbol{z},\quad \boldsymbol{z}\sim \mathcal{N}(\boldsymbol{0},\boldsymbol{I}) \end{aligned} $$

This means that the SGD iteration format $\boldsymbol{\theta}_{t+1} =\boldsymbol{\theta}_t - \eta \tilde{\boldsymbol{g}}_{B,t}$ is actually approximately solving the SDE:

$$ d\boldsymbol{\theta} = - \boldsymbol{g}_t dt - \sqrt{\frac{\eta\boldsymbol{\Sigma}_t}{B}}d\boldsymbol{w} $$

Therefore, to ensure that the results do not change significantly when $B$ varies, the form of the above SDE should remain constant, which leads to linear scaling $\eta\propto B$. The most crucial step in this process is that the step size of the noise term in SDE is the square root of the non-noise term, thus isolating an $\sqrt{\eta}$ term. We also discussed this point in 《A Casual Talk on Generative Diffusion Models (V): SDEs as a General Framework》. Simply put, zero-mean Gaussian noise will have some cancellation effect over the long term, so the step size must be increased to manifest the noise effect.

The conclusions above are all based on the SGD optimizer. The paper 《On the SDEs and Scaling Rules for Adaptive Gradient Algorithms》 extended them to optimizers like RMSProp and Adam, resulting in square root scaling. Coincidentally, an earlier paper, 《Large Batch Optimization for Deep Learning: Training BERT in 76 minutes》, also applied square root scaling when testing Adam and its variant LAMB. More content can be found in the blog post 《How to Scale Hyperparameters as Batch Size Increases》.
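The SDE argument for SGD can be illustrated on a toy problem. The sketch below assumes a one-dimensional quadratic loss $\mathcal{L}=h\theta^2/2$ with Gaussian gradient noise of variance $\sigma^2/B$ (a setup chosen here for illustration, not one used in the article). For this toy problem the stationary variance of $\theta$ has a closed form, and for small $\eta h$ it depends on $\eta$ and $B$ almost only through $\eta/B$.

```python
def stationary_variance(lr: float, batch: int, h: float = 1.0, sigma2: float = 1.0) -> float:
    """Stationary variance of theta under theta <- (1 - lr*h)*theta - lr*eps,
    with Var(eps) = sigma2 / batch. It is the fixed point of
    V = (1 - lr*h)^2 * V + lr^2 * sigma2 / batch.
    """
    a = 1.0 - lr * h
    if abs(a) >= 1.0:
        raise ValueError("lr*h must lie in (0, 2) for the iteration to be stable")
    return lr ** 2 * sigma2 / batch / (1.0 - a * a)

base = stationary_variance(lr=1e-3, batch=32)
linear = stationary_variance(lr=4e-3, batch=128)   # eta scaled in proportion to B
sqrt_rule = stationary_variance(lr=2e-3, batch=128)  # eta scaled in proportion to sqrt(B)

print("linear / base:", linear / base)      # close to 1: the noise level is preserved
print("sqrt   / base:", sqrt_rule / base)   # about one half: the noise level shrinks
```

In this toy setting linear scaling keeps the long-run fluctuation of the parameter nearly unchanged, while square root scaling lets it shrink as $B$ grows.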

Directly Facing the Loss
#

Whether it is square root scaling or linear scaling, these rules can clearly only hold approximately within a local range, because both imply that “as long as the Batch Size is large enough, the learning rate can be arbitrarily large,” which is impossible. Furthermore, the previous two sections focused on variance, but our fundamental task is to minimize the loss function. Therefore, a loss-function-oriented approach might be more fundamental.

Monotonically Increasing with an Upper Bound
#

A classic work from this perspective is OpenAI’s 《An Empirical Model of Large-Batch Training》, which analyzed the optimal learning rate for SGD through a second-order approximation of the loss function, concluding that “the learning rate increases monotonically with Batch Size but has an upper bound.” The same idea also appeared in the slightly earlier paper 《Dissecting Adam: The Sign, Magnitude and Variance of Stochastic Gradients》, though that paper was not used to analyze the effect of Batch Size.

The most crucial idea throughout the derivation process is to treat the learning rate as an optimization parameter: Let the loss function be $\mathcal{L}(\boldsymbol{\theta})$, and the gradient of the current Batch be $\tilde{\boldsymbol{g}}_B$. Then the loss function after SGD is $\mathcal{L}(\boldsymbol{\theta} - \eta\tilde{\boldsymbol{g}}_B)$. We treat the optimal learning rate problem as an optimization problem:

$$ \eta^* = \mathop{\text{argmin}}_{\eta} \mathbb{E}[\mathcal{L}(\boldsymbol{\theta} - \eta\tilde{\boldsymbol{g}}_B)] $$

This objective is clearly intuitive: choose the learning rate that leads to the fastest descent of the loss function on average. To solve this problem, we approximately expand the loss function to the second order:

$$ \mathcal{L}(\boldsymbol{\theta} - \eta\tilde{\boldsymbol{g}}_B) \approx \mathcal{L}(\boldsymbol{\theta}) - \eta\tilde{\boldsymbol{g}}_B^{\top}\underbrace{\frac{\partial \mathcal{L}(\boldsymbol{\theta})}{\partial\boldsymbol{\theta}}}_{\text{is }\boldsymbol{g}} + \frac{1}{2}\eta^2 \tilde{\boldsymbol{g}}_B^{\top}\underbrace{\frac{\partial^2 \mathcal{L}(\boldsymbol{\theta})}{\partial\boldsymbol{\theta}^2}}_{\text{denoted as }\boldsymbol{H}}\tilde{\boldsymbol{g}}_B = \mathcal{L}(\boldsymbol{\theta}) - \eta\tilde{\boldsymbol{g}}_B^{\top}\boldsymbol{g} + \frac{1}{2}\eta^2 \tilde{\boldsymbol{g}}_B^{\top}\boldsymbol{H}\tilde{\boldsymbol{g}}_B $$

Here, $\boldsymbol{H}$ is the Hessian matrix, and $\frac{\partial \mathcal{L}(\boldsymbol{\theta})}{\partial\boldsymbol{\theta}}$ is the gradient of the loss function. The ideal objective function is based on all samples, which is why its gradient is the mean of $\tilde{\boldsymbol{g}}_B$, which is $\boldsymbol{g}$. Next, we take the expectation:

$$ \mathbb{E}[\mathcal{L}(\boldsymbol{\theta} - \eta\tilde{\boldsymbol{g}}_B)] \approx \mathbb{E}[\mathcal{L}(\boldsymbol{\theta}) - \eta\tilde{\boldsymbol{g}}_B^{\top}\boldsymbol{g} + \frac{1}{2}\eta^2 \tilde{\boldsymbol{g}}_B^{\top}\boldsymbol{H}\tilde{\boldsymbol{g}}_B] = \mathcal{L}(\boldsymbol{\theta}) - \eta\boldsymbol{g}^{\top}\boldsymbol{g} + \frac{1}{2}\eta^2 \mathbb{E}[\tilde{\boldsymbol{g}}_B^{\top}\boldsymbol{H}\tilde{\boldsymbol{g}}_B] $$

The last term requires a slight trick:

$$ \begin{aligned}\mathbb{E}[\tilde{\boldsymbol{g}}_B^{\top}\boldsymbol{H}\tilde{\boldsymbol{g}}_B] =&\, \mathbb{E}[\text{Tr}(\tilde{\boldsymbol{g}}_B^{\top}\boldsymbol{H}\tilde{\boldsymbol{g}}_B)]= \mathbb{E}[\text{Tr}(\tilde{\boldsymbol{g}}_B\tilde{\boldsymbol{g}}_B^{\top}\boldsymbol{H})] = \text{Tr}(\mathbb{E}[\tilde{\boldsymbol{g}}_B\tilde{\boldsymbol{g}}_B^{\top}]\boldsymbol{H})\\[5pt] =&\, \text{Tr}((\boldsymbol{g}\boldsymbol{g}^{\top} + \boldsymbol{\Sigma}/B)\boldsymbol{H}) = \boldsymbol{g}^{\top}\boldsymbol{H}\boldsymbol{g} + \text{Tr}(\boldsymbol{\Sigma}\boldsymbol{H})/B\end{aligned} $$

The transformation mainly uses $\text{Tr}(\boldsymbol{A}\boldsymbol{B}) = \text{Tr}(\boldsymbol{B}\boldsymbol{A})$. Now, assuming the positive definiteness of $\boldsymbol{H}$, the problem becomes minimizing a quadratic function, which can be easily solved to yield:

$$ \eta^* \approx \frac{\boldsymbol{g}^{\top}\boldsymbol{g}}{\boldsymbol{g}^{\top}\boldsymbol{H}\boldsymbol{g} + \text{Tr}(\boldsymbol{\Sigma}\boldsymbol{H})/B} = \frac{\eta_{\max}}{1 + \mathcal{B}_{\text{noise}}/B} \quad \text{(1)} $$

This results in the conclusion that “as $B$ increases, it monotonically increases with an upper bound,” where:

$$ \eta_{\max} = \frac{\boldsymbol{g}^{\top}\boldsymbol{g}}{\boldsymbol{g}^{\top}\boldsymbol{H}\boldsymbol{g}},\qquad\mathcal{B}_{\text{noise}} = \frac{\text{Tr}(\boldsymbol{\Sigma}\boldsymbol{H})}{\boldsymbol{g}^{\top}\boldsymbol{H}\boldsymbol{g}} $$
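The expectation $\mathbb{E}[\tilde{\boldsymbol{g}}_B^{\top}\boldsymbol{H}\tilde{\boldsymbol{g}}_B] = \boldsymbol{g}^{\top}\boldsymbol{H}\boldsymbol{g} + \text{Tr}(\boldsymbol{\Sigma}\boldsymbol{H})/B$ can be checked exactly on a tiny example. The sketch below assumes a hypothetical dataset of three per-sample gradients in two dimensions, a hand-picked positive definite $\boldsymbol{H}$, and batches drawn with replacement, so every possible batch can be enumerated.

```python
from itertools import product

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def quad_form(v, H):
    """Return v^T H v."""
    return dot(v, [dot(row, v) for row in H])

# Hypothetical per-sample gradients and Hessian (illustrative values only).
sample_grads = [[1.0, 0.0], [0.0, 2.0], [2.0, 1.0]]
H = [[2.0, 0.5], [0.5, 1.0]]

n, dim = len(sample_grads), len(H)
g = [sum(s[k] for s in sample_grads) / n for k in range(dim)]
# Population covariance of a single sampled gradient.
Sigma = [[sum((s[i] - g[i]) * (s[j] - g[j]) for s in sample_grads) / n
          for j in range(dim)] for i in range(dim)]
trace_sigma_h = sum(Sigma[i][j] * H[j][i] for i in range(dim) for j in range(dim))

def expected_quad_exact(batch):
    """E[g_B^T H g_B], averaging over every batch drawn with replacement."""
    total = 0.0
    for idx in product(range(n), repeat=batch):
        g_b = [sum(sample_grads[i][k] for i in idx) / batch for k in range(dim)]
        total += quad_form(g_b, H)
    return total / n ** batch

def expected_quad_closed(batch):
    """g^T H g + Tr(Sigma H) / B."""
    return quad_form(g, H) + trace_sigma_h / batch

def eta_star(batch):
    """Equation (1): g^T g / (g^T H g + Tr(Sigma H) / B)."""
    return dot(g, g) / expected_quad_closed(batch)

eta_max = dot(g, g) / quad_form(g, H)
for batch in (1, 2, 3):
    print(batch, expected_quad_exact(batch), expected_quad_closed(batch), eta_star(batch))
print("eta_max:", eta_max)
```

The enumerated and closed-form values agree for each $B$, and $\eta^*$ rises toward $\eta_{\max}$ without exceeding it.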

Practical Analysis
#

When $B \ll \mathcal{B}_{\text{noise}}$, we have $1 + \mathcal{B}_{\text{noise}}/B\approx \mathcal{B}_{\text{noise}}/B$, so $\eta^* \approx \eta_{\max}B/\mathcal{B}_{\text{noise}}\propto B$, which is linear scaling. This again demonstrates that linear scaling is merely a local approximation for small Batch Sizes. When $B > \mathcal{B}_{\text{noise}}$, $\eta^*$ gradually approaches the saturation value $\eta_{\max}$, meaning that the increase in training cost far outweighs the improvement in training efficiency. Therefore, $\mathcal{B}_{\text{noise}}$ acts as a watershed: once the Batch Size exceeds this value, there’s no need to continue investing computational power to increase the Batch Size.
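The two regimes of equation (1) are easy to see numerically. In this sketch `eta_max` and `b_noise` are placeholder values chosen for illustration; in practice both would have to be estimated.

```python
def optimal_sgd_lr(batch: float, eta_max: float, b_noise: float) -> float:
    """Equation (1): eta* = eta_max / (1 + B_noise / B)."""
    return eta_max / (1.0 + b_noise / batch)

eta_max, b_noise = 1.0, 1000.0  # hypothetical values
for batch in (1, 2, 4, 1000, 100_000, 1_000_000):
    lr = optimal_sgd_lr(batch, eta_max, b_noise)
    print(f"B={batch:>9}  eta*={lr:.6f}  eta*/B={lr / batch:.3e}")
```

For small $B$ the ratio $\eta^*/B$ is nearly constant (linear scaling), at $B=\mathcal{B}_{\text{noise}}$ the optimal learning rate is exactly half of $\eta_{\max}$, and beyond that further increases in $B$ buy very little.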

For practical application, the most crucial problem is undoubtedly how to estimate $\eta_{\max}$ and $\mathcal{B}_{\text{noise}}$, especially since $\mathcal{B}_{\text{noise}}$ directly relates to the learning rate scaling law and the saturation of training efficiency. Direct computation of both involves the Hessian matrix $\boldsymbol{H}$, whose computational cost is proportional to the square of the number of parameters. In an era where hundreds of millions of parameters are considered small models, calculating the Hessian matrix is clearly impractical, so more effective calculation methods must be sought.

Let’s first look at $\mathcal{B}_{\text{noise}}$. Its formula is $\frac{\text{Tr}(\boldsymbol{\Sigma}\boldsymbol{H})}{\boldsymbol{g}^{\top}\boldsymbol{H}\boldsymbol{g}}$. Both the numerator and denominator have an $\boldsymbol{H}$, which undoubtedly gives us an urge to “cancel them out.” In fact, the simplification idea is exactly this: assume $\boldsymbol{H}$ is approximately a scalar multiple of the identity matrix, then we get:

$$ \mathcal{B}_{\text{noise}} = \frac{\text{Tr}(\boldsymbol{\Sigma}\boldsymbol{H})}{\boldsymbol{g}^{\top}\boldsymbol{H}\boldsymbol{g}}\approx \frac{\text{Tr}(\boldsymbol{\Sigma})}{\boldsymbol{g}^{\top}\boldsymbol{g}}\triangleq \mathcal{B}_{\text{simple}} $$

$\mathcal{B}_{\text{simple}}$ is more computationally feasible, and experiments show that it is often a good approximation of $\mathcal{B}_{\text{noise}}$. Therefore, we choose to estimate $\mathcal{B}_{\text{simple}}$ instead of $\mathcal{B}_{\text{noise}}$. Note that $\text{Tr}(\boldsymbol{\Sigma})$ only requires the diagonal elements, so there’s no need to calculate the full covariance matrix; we just need to calculate the variance for each gradient component individually and then sum them. In data parallel scenarios, the gradient variance can be estimated directly using the gradients calculated on each device.
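One possible way to turn per-device gradients into an estimate of $\mathcal{B}_{\text{simple}}$ is sketched below. It relies on $\mathbb{E}[\Vert\tilde{\boldsymbol{g}}_B\Vert^2] = \Vert\boldsymbol{g}\Vert^2 + \text{Tr}(\boldsymbol{\Sigma})/B$ (the trace identity from the derivation above with $\boldsymbol{H}$ replaced by the identity) evaluated at two batch sizes: the per-device batch and the global batch. The function name and input layout are assumptions made for this example, and this is not necessarily the exact estimator the article has in mind.

```python
def squared_norm(v):
    return sum(x * x for x in v)

def estimate_b_simple(device_grads, per_device_batch):
    """Estimate Tr(Sigma) / |g|^2 from one gradient vector per device.

    device_grads: list of gradient vectors, each averaged over
                  `per_device_batch` samples on its own device.
    Returns (b_simple, trace_sigma, grad_norm2).
    """
    n = len(device_grads)
    if n < 2:
        raise ValueError("need gradients from at least two devices")
    dim = len(device_grads[0])
    b_small = per_device_batch
    b_big = n * per_device_batch

    mean_grad = [sum(g[k] for g in device_grads) / n for k in range(dim)]
    norm2_small = sum(squared_norm(g) for g in device_grads) / n  # estimates |g|^2 + Tr(Sigma)/b_small
    norm2_big = squared_norm(mean_grad)                           # estimates |g|^2 + Tr(Sigma)/b_big

    # Solve the two equations above for |g|^2 and Tr(Sigma).
    grad_norm2 = (b_big * norm2_big - b_small * norm2_small) / (b_big - b_small)
    trace_sigma = (norm2_small - norm2_big) / (1.0 / b_small - 1.0 / b_big)
    # Note: on a single noisy step grad_norm2 can come out near zero or negative,
    # so in practice one would average both quantities over many steps first.
    return trace_sigma / grad_norm2, trace_sigma, grad_norm2

# Tiny hand-checkable example: two devices, 4 samples each.
b_simple, trace_sigma, grad_norm2 = estimate_b_simple([[3.0, 0.0], [1.0, 0.0]], per_device_batch=4)
print(b_simple, trace_sigma, grad_norm2)
```

Only squared norms of gradients are needed, so the full covariance matrix is never formed.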

It should be pointed out that equation (1) and other results are actually dynamic, meaning that in theory, $\eta_{\max}$, $\mathcal{B}_{\text{noise}}$, and $\mathcal{B}_{\text{simple}}$ are different at each training step. Therefore, if we want to obtain a static law, we need to train for a period until the model’s training enters a “normal track” before the calculated $\mathcal{B}_{\text{simple}}$ becomes reliable. Alternatively, we can continuously monitor $\mathcal{B}_{\text{simple}}$ during training to judge the gap between the current settings and the optimum.

As for $\eta_{\max}$, there is no need to estimate it directly from the formula. Instead, one can perform a grid search for the learning rate at a certain small Batch Size to find an approximate $\eta^*$, and then combine it with the estimated $\mathcal{B}_{\text{simple}}$ to deduce $\eta_{\max}$.
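This deduction is just equation (1) solved for $\eta_{\max}$. A minimal sketch, with hypothetical numbers standing in for the grid-search result and the estimated $\mathcal{B}_{\text{simple}}$:

```python
def eta_max_from_probe(eta_star_probe: float, probe_batch: float, b_simple: float) -> float:
    """Invert equation (1): eta_max = eta* x (1 + B_simple / B)."""
    return eta_star_probe * (1.0 + b_simple / probe_batch)

def predicted_lr(batch: float, eta_max: float, b_simple: float) -> float:
    """Equation (1) with B_noise replaced by its estimate B_simple."""
    return eta_max / (1.0 + b_simple / batch)

# Hypothetical inputs: grid search at B=32 found eta* ~ 1e-3, and B_simple ~ 992.
eta_max = eta_max_from_probe(1e-3, probe_batch=32, b_simple=992.0)
for batch in (32, 256, 992, 8192):
    print(batch, predicted_lr(batch, eta_max, 992.0))
```

The predicted values are only as good as the estimate of $\mathcal{B}_{\text{simple}}$ and the assumption that it stays roughly constant during training.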

Data Efficiency
#

From the above results, we can also derive an asymptotic relationship concerning the amount of training data and the number of training steps. The derivation process is also simple: substituting equation (1) into the loss function, we can calculate that the reduction in the loss function per iteration at the optimal learning rate is:

$$ \Delta\mathcal{L} = \mathcal{L}(\boldsymbol{\theta}) - \mathbb{E}[\mathcal{L}(\boldsymbol{\theta} - \eta^*\tilde{\boldsymbol{g}}_B)] \approx \frac{\Delta\mathcal{L}_{\max}}{1 + \mathcal{B}_{\text{noise}}/B} \quad \text{(2)} $$

where $\Delta\mathcal{L}_{\max} = \frac{(\boldsymbol{g}^{\top}\boldsymbol{g})^2}{2\boldsymbol{g}^{\top}\boldsymbol{H}\boldsymbol{g}}$. The key focus now is the interpretation of this result.

When $B\to\infty$, i.e., full-batch SGD, the loss function reduction per step reaches its maximum $\Delta\mathcal{L}_{\max}$. At this point, the target can be reached with the fewest training steps (denoted as $S_{\min}$). When $B$ is finite, the average loss reduction per step is only $\Delta\mathcal{L}$, which means we need $1 + \mathcal{B}_{\text{noise}}/B$ steps to achieve the same reduction as a single step of full-batch SGD. Therefore, the total number of training steps is approximately $S = (1 + \mathcal{B}_{\text{noise}}/B)S_{\min}$.

Since the Batch Size is $B$, the total number of samples consumed during training is $E = BS = (B + \mathcal{B}_{\text{noise}})S_{\min}$, which is an increasing function of $B$. As $B\to 0$, $E$ approaches its minimum $E_{\min} = \mathcal{B}_{\text{noise}}S_{\min}$. This indicates that the smaller the Batch Size used to train the model, the fewer total training samples $E$ are required, at the cost of a very large number of training steps $S$. Furthermore, using these notations, we can write their relationship as:

$$ \left(\frac{S}{S_{\min}} - 1\right)\left(\frac{E}{E_{\min}} - 1\right) = 1 \quad \text{(3)} $$

This is the scaling law between the amount of training data and the number of training steps, indicating that the smaller the amount of data, the smaller the Batch Size should be, leading to more training steps, in order to have a better chance of reaching an optimal solution. The derivation here has been simplified by the author, assuming the invariance of $\mathcal{B}_{\text{noise}}$ and $\Delta\mathcal{L}_{\max}$ throughout the training process. If necessary, the dynamic changes could be handled more precisely using integration as in the original paper’s appendix (but would require introducing the assumption $B = \sqrt{r\mathcal{B}_{\text{noise}}}$); we won’t elaborate on that here.

Moreover, since $\mathcal{B}_{\text{noise}} = E_{\min}/S_{\min}$, equation (3) also provides another way to estimate $\mathcal{B}_{\text{noise}}$: by performing multiple experiments and a grid search to obtain several $(S,E)$ pairs, then fitting the equation to estimate $E_{\min},S_{\min}$, and subsequently calculating $\mathcal{B}_{\text{noise}}$.
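One way to carry out this fit: expanding equation (3) gives $S_{\min}/S + E_{\min}/E = 1$, which is linear in the unknowns $S_{\min}$ and $E_{\min}$, so ordinary least squares on the features $(1/S, 1/E)$ suffices. The sketch below is an assumption about how the fit might be done, and the $(S,E)$ pairs are synthetic, generated from $S_{\min}=1000$ and $\mathcal{B}_{\text{noise}}=64$ rather than measured.

```python
def fit_s_min_e_min(pairs):
    """Least-squares fit of S_min / S + E_min / E = 1 over (S, E) pairs."""
    xs = [1.0 / s for s, _ in pairs]
    ys = [1.0 / e for _, e in pairs]
    sxx = sum(x * x for x in xs)
    syy = sum(y * y for y in ys)
    sxy = sum(x * y for x, y in zip(xs, ys))
    sx, sy = sum(xs), sum(ys)
    # Normal equations: [[sxx, sxy], [sxy, syy]] @ [S_min, E_min] = [sx, sy]
    det = sxx * syy - sxy * sxy
    if det == 0:
        raise ValueError("need at least two distinct (S, E) pairs")
    s_min = (sx * syy - sy * sxy) / det
    e_min = (sy * sxx - sx * sxy) / det
    return s_min, e_min

# Synthetic pairs for B = 16, 64, 256 (S = (1 + 64/B) * 1000, E = B * S).
pairs = [(5000.0, 80000.0), (2000.0, 128000.0), (1250.0, 320000.0)]
s_min, e_min = fit_s_min_e_min(pairs)
print("S_min:", s_min, "E_min:", e_min, "B_noise:", e_min / s_min)
```

With real, noisy measurements the features span several orders of magnitude, so rescaling $S$ and $E$ before fitting may be advisable.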

Adaptive Version
#

It must be said that OpenAI is indeed one of the pioneers of various Scaling Laws. The aforementioned analysis is quite brilliant, and the results are also very rich. What’s more commendable is that the entire derivation process is not complicated, giving a sense of profound simplicity. However, the current conclusions are all derived based on SGD, and their applicability to adaptive learning rate optimizers like Adam is still unclear. This gap is filled by 《Surge Phenomenon in Optimal Learning Rate and Batch Size Scaling》.

Sign Approximation
#

The approach to analyzing Adam is the same as for SGD, both based on a second-order expansion. The difference is that the direction vector changes from $\tilde{\boldsymbol{g}}_B$ to a general vector $\tilde{\boldsymbol{u}}_B$. In this case, we have:

$$ \mathbb{E}[\mathcal{L}(\boldsymbol{\theta} - \eta\tilde{\boldsymbol{u}}_B)] \approx \mathcal{L}(\boldsymbol{\theta}) - \eta\mathbb{E}[\tilde{\boldsymbol{u}}_B]^{\top}\boldsymbol{g} + \frac{1}{2}\eta^2 \text{Tr}(\mathbb{E}[\tilde{\boldsymbol{u}}_B\tilde{\boldsymbol{u}}_B^{\top}]\boldsymbol{H}) $$

Now we need to determine $\tilde{\boldsymbol{u}}_B$ and calculate the corresponding $\mathbb{E}[\tilde{\boldsymbol{u}}_B]$ and $\mathbb{E}[\tilde{\boldsymbol{u}}_B\tilde{\boldsymbol{u}}_B^{\top}]$. Since only an asymptotic relationship is needed, similar to the “Linear Scaling” section and 《Configuring Different Learning Rates, Can LoRA Improve Further?》, we choose SignSGD, i.e., $\tilde{\boldsymbol{u}}_B = \text{sign}(\tilde{\boldsymbol{g}}_B)$, as an approximation for Adam. This approach may have originated from 《Dissecting Adam: The Sign, Magnitude and Variance of Stochastic Gradients》. The rationality of this approximation is reflected in two points:

  1. Regardless of the values of $\beta_1, \beta_2$, Adam’s first update vector is always $\text{sign}(\tilde{\boldsymbol{g}}_B)$.

  2. When $\beta_1=\beta_2=0$, Adam’s update vector is always $\text{sign}(\tilde{\boldsymbol{g}}_B)$.

To compute $\mathbb{E}[\tilde{\boldsymbol{u}}_B]$ and $\mathbb{E}[\tilde{\boldsymbol{u}}_B\tilde{\boldsymbol{u}}_B^{\top}]$, we also need to assume, as in the “Linear Scaling” section, that $\tilde{\boldsymbol{g}}_B$ follows a distribution $\mathcal{N}(\boldsymbol{g},\boldsymbol{\Sigma}/B)$. To simplify calculations, we further assume that $\boldsymbol{\Sigma}$ is a diagonal matrix $\text{diag}(\sigma_1^2,\sigma_2^2,\sigma_3^2,\cdots)$, i.e., assuming the components are independent. In this way, we can handle each component independently. According to the reparameterization trick, we know that $\tilde{g}_B\sim \mathcal{N}(g, \sigma^2/B)$ is equivalent to $\tilde{g}_B=g + \sigma z/\sqrt{B},z\sim\mathcal{N}(0,1)$, thus:

$$ \begin{aligned}\mathbb{E}[\tilde{u}_B] =&\, \mathbb{E}[\text{sign}(g + \sigma z/\sqrt{B})] = \mathbb{E}[\text{sign}(g\sqrt{B}/\sigma + z)] \\[5pt] =&\, \frac{1}{\sqrt{2\pi}}\int_{-\infty}^{\infty} \text{sign}(g\sqrt{B}/\sigma + z) e^{-z^2/2}dz \\[5pt] =&\, \frac{1}{\sqrt{2\pi}}\int_{-\infty}^{-g\sqrt{B}/\sigma} (-1)\times e^{-z^2/2}dz + \frac{1}{\sqrt{2\pi}}\int_{-g\sqrt{B}/\sigma}^{\infty} 1\times e^{-z^2/2}dz \\[5pt] =&\, \text{erf}\left(\frac{g}{\sigma}\sqrt{\frac{B}{2}}\right)\end{aligned} $$

Here, $\text{erf}$ is the error function, which is an S-shaped function with a range of $(-1,1)$, similar to $\tanh$, and can be used as a smooth approximation of $\text{sign}$. However, $\text{erf}$ itself does not have an elementary function expression, so we’d better find an elementary function approximation to observe the changing pattern more intuitively. We discussed this topic before in 《How GELU’s Two Elementary Function Approximations Are Derived》, but those approximations were still too complex (involving exponential operations). Here, we’ll make a simpler one:

$$ \text{erf}(x)\approx \text{sign}(x) = \frac{x}{|x|} = \frac{x}{\sqrt{x^2}}\approx \frac{x}{\sqrt{x^2+c}} $$

We choose $c=\pi/4$ so that the first-order Taylor expansion of this approximation at $x=0$ matches that of $\text{erf}$ (both have slope $2/\sqrt{\pi}$ there). Of course, with so many levels of approximation already, the exact value of $c$ is not that important; we just need to know that such a $c > 0$ exists. Based on this approximation, we get:

$$ \mathbb{E}[\tilde{u}_B] \approx \frac{g/\sigma}{\sqrt{\pi/2B+(g/\sigma)^2}}\quad\Rightarrow\quad\mathbb{E}[\tilde{\boldsymbol{u}}_B]_i \approx \frac{g_i/\sigma_i}{\sqrt{\pi/2B+(g_i/\sigma_i)^2}}\triangleq \mu_i $$
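To get a feel for how rough this approximation is, the sketch below compares `math.erf` with $x/\sqrt{x^2+\pi/4}$ and evaluates $\mu_i$ both ways. The helper names and sample inputs are illustrative assumptions.

```python
import math

C = math.pi / 4  # chosen so the slope at x = 0 matches erf'(0) = 2 / sqrt(pi)

def erf_approx(x: float) -> float:
    """Elementary approximation erf(x) ~ x / sqrt(x^2 + pi/4)."""
    return x / math.sqrt(x * x + C)

def mu_exact(g: float, sigma: float, batch: float) -> float:
    """E[sign(g_B)] for g_B ~ N(g, sigma^2 / B): erf((g / sigma) * sqrt(B / 2))."""
    return math.erf(g / sigma * math.sqrt(batch / 2))

def mu_approx(g: float, sigma: float, batch: float) -> float:
    """The approximation mu_i = (g / sigma) / sqrt(pi / (2B) + (g / sigma)^2)."""
    r = g / sigma
    return r / math.sqrt(math.pi / (2 * batch) + r * r)

for x in (0.1, 0.5, 1.0, 2.0, 4.0):
    print(f"x={x:<4} erf={math.erf(x):.4f} approx={erf_approx(x):.4f}")

for batch in (1, 16, 256, 4096):
    print(batch, mu_exact(0.3, 2.0, batch), mu_approx(0.3, 2.0, batch))
```

The two curves agree near zero and at the saturated end, and differ most at intermediate $x$, which is acceptable when only the qualitative dependence on $B$ matters.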

It can be observed that a clear difference between Adam and SGD is that $\mathbb{E}[\tilde{\boldsymbol{u}}_B]$ already depends on $B$ at this step. Fortunately, the second moment is simpler now, because the square of $\text{sign}(x)$ must be 1, so:

$$ \mathbb{E}[\tilde{u}_B^2] = 1\quad\Rightarrow\quad\mathbb{E}[\tilde{\boldsymbol{u}}_B\tilde{\boldsymbol{u}}_B^{\top}]_{i,j} \to\left\{\begin{aligned}&=1, & i = j \\ & \approx\mu_i \mu_j,&\,i\neq j\end{aligned}\right. $$

Using these results, we can find:

$$ \begin{gather}\eta^* \approx \frac{\mathbb{E}[\tilde{\boldsymbol{u}}_B]^{\top}\boldsymbol{g}}{\text{Tr}(\mathbb{E}[\tilde{\boldsymbol{u}}_B\tilde{\boldsymbol{u}}_B^{\top}]\boldsymbol{H})} \approx \frac{\sum_i \mu_i g_i}{\sum_i H_{i,i} + \sum_{i\neq j} \mu_i \mu_j H_{i,j}} \quad \text{(4)} \\[5pt]\Delta \mathcal{L} = \mathcal{L}(\boldsymbol{\theta}) - \mathbb{E}[\mathcal{L}(\boldsymbol{\theta} - \eta^*\tilde{\boldsymbol{u}}_B)] \approx \frac{1}{2}\frac{(\sum_i \mu_i g_i)^2}{\sum_i H_{i,i} + \sum_{i\neq j} \mu_i \mu_j H_{i,j}} \quad \text{(5)}\end{gather} $$

Two Special Cases
#

Compared to equation (1) for SGD, equation (4) for Adam is more complex, making it difficult to intuitively grasp its dependence on $B$. So, we start with a few special cases.

First, consider $B\to\infty$. At this point, $\mu_i = \text{sign}(g_i)$, so:

$$ \eta^* \approx \frac{\sum_i |g_i|}{\sum_i H_{i,i} + \sum_{i\neq j} \text{sign}(g_i g_j) H_{i,j}} $$

It differs from the $\eta_{\max}$ of SGD in that it is not invariant to the scale of the gradient (i.e., not homogeneous of degree zero); instead, it is proportional to the gradient scale.

Next, let’s consider the case where $\boldsymbol{H}$ is a diagonal matrix, i.e., $H_{i,j}=0$ when $i\neq j$. In this case:

$$ \eta^* \approx \frac{\sum_i \mu_i g_i}{\sum_i H_{i,i}}=\frac{1}{\sum_i H_{i,i}}\sum_i \frac{g_i^2/\sigma_i}{\sqrt{\pi/2B+(g_i/\sigma_i)^2}} $$

Each term in this sum monotonically increases with $B$ and has an upper bound, so the overall result also behaves similarly. To capture the most essential pattern, we can further simplify $\mu_i$ (this part differs from the original paper):

$$ \mu_i = \frac{g_i/\sigma_i}{\sqrt{\pi/2B+(g_i/\sigma_i)^2}} = \frac{\text{sign}(g_i)}{\sqrt{1 + \pi(\sigma_i/g_i)^2/2B}} \approx \frac{\text{sign}(g_i)}{\sqrt{1 + \pi\kappa^2/2B}} \quad \text{(6)} $$

Here, the assumption is that there exists a constant $\kappa^2$ independent of $i$ (for example, one could consider some kind of average of all $(\sigma_i/g_i)^2$; in fact, $\kappa^2$ here is similar to the $\mathcal{B}_{\text{simple}}$ from before, and can also be estimated according to the definition of $\mathcal{B}_{\text{simple}}$), such that replacing $(\sigma_i/g_i)^2$ with $\kappa^2$ is a good approximation for any $i$. Thus:

$$ \eta^* \approx \frac{\sum_i \mu_i g_i}{\sum_i H_{i,i}}\approx \frac{\sum_i |g_i|}{\sum_i H_{i,i}}\frac{1}{\sqrt{1 + \pi\kappa^2/2B}} \quad \text{(7)} $$

When $\pi\kappa^2\gg 2B$, i.e., $B \ll \pi\kappa^2/2$, it can be further approximated as:

$$ \eta^* \approx \frac{\sum_i |g_i|}{\kappa\sum_i H_{i,i}}\sqrt{\frac{2B}{\pi}} \propto \sqrt{B} $$

This indicates that when the Batch Size itself is small, Adam indeed conforms to the square root scaling law.
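Equation (7) can be evaluated directly to see the transition from square root scaling to saturation. In this sketch `eta_inf` stands for the prefactor $\sum_i |g_i| / \sum_i H_{i,i}$, and both it and `kappa2` are placeholder values.

```python
import math

def adam_lr_diag(batch: float, kappa2: float, eta_inf: float) -> float:
    """Equation (7): eta* ~ eta_inf / sqrt(1 + pi * kappa^2 / (2B))."""
    return eta_inf / math.sqrt(1.0 + math.pi * kappa2 / (2.0 * batch))

kappa2, eta_inf = 1e6, 1.0  # hypothetical values
for batch in (1, 4, 16, 1e6, 1e8, 1e10):
    lr = adam_lr_diag(batch, kappa2, eta_inf)
    print(f"B={batch:<12g} eta*={lr:.6e}  eta*/sqrt(B)={lr / math.sqrt(batch):.6e}")
```

While $B \ll \pi\kappa^2/2$, quadrupling $B$ doubles $\eta^*$; once $B$ is far above that threshold, $\eta^*$ levels off at `eta_inf`.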

Surge Phenomenon
#

If we apply approximation (6) to the original equation (4), we will find that it exhibits some completely new characteristics. Specifically, we have:

$$ \eta^* \approx \frac{\sum_i \mu_i g_i}{\sum_i H_{i,i} + \sum_{i\neq j} \mu_i \mu_j H_{i,j}} \approx \frac{\eta_{\max}}{\frac{1}{2}\left(\frac{\beta_{\text{noise}}}{\beta} + \frac{\beta}{\beta_{\text{noise}}}\right)} \quad \text{(8)} $$

where $\beta = (1 + \pi\kappa^2/2B)^{-1/2}$, and

$$ \beta_{\text{noise}} = \sqrt{\frac{\sum_i H_{i,i}}{\sum_{i\neq j}\text{sign}(g_i g_j) H_{i,j}}},\quad \eta_{\max} = \frac{\sum_i |g_i|}{2\sqrt{\left(\sum_i H_{i,i}\right)\left(\sum_{i\neq j} \text{sign}(g_i g_j) H_{i,j}\right)}} $$

Note that $\beta$ is a monotonically increasing function of $B$, but the final approximation in equation (8) is not a monotonically increasing function of $\beta$; it first increases and then decreases, reaching its maximum when $\beta=\beta_{\text{noise}}$. This implies that there exists a corresponding $\mathcal{B}_{\text{noise}}$ such that when the Batch Size exceeds this $\mathcal{B}_{\text{noise}}$, the optimal learning rate should decrease instead of increase! This is the “Surge phenomenon” mentioned in the title of the original paper. (Of course, there is also a constraint here: $\beta$ is always less than 1. If $\beta_{\text{noise}} \geq 1$, then the relationship between the optimal learning rate and Batch Size remains monotonically increasing.)
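The non-monotonic behaviour of equation (8) can be reproduced in a few lines. The values of `kappa2`, `beta_noise`, and `eta_max` below are hypothetical, and `peak_batch` simply solves $\beta=\beta_{\text{noise}}$ for $B$ using the definition of $\beta$.

```python
import math

def beta_of_batch(batch: float, kappa2: float) -> float:
    """beta = (1 + pi * kappa^2 / (2B))^(-1/2), always in (0, 1)."""
    return 1.0 / math.sqrt(1.0 + math.pi * kappa2 / (2.0 * batch))

def adam_lr_surge(batch: float, kappa2: float, beta_noise: float, eta_max: float) -> float:
    """Equation (8): eta* ~ eta_max / (0.5 * (beta_noise / beta + beta / beta_noise))."""
    b = beta_of_batch(batch, kappa2)
    return eta_max / (0.5 * (beta_noise / b + b / beta_noise))

def peak_batch(kappa2: float, beta_noise: float):
    """Batch Size at which beta == beta_noise, or None if beta_noise >= 1."""
    if beta_noise >= 1.0:
        return None
    return math.pi * kappa2 * beta_noise ** 2 / (2.0 * (1.0 - beta_noise ** 2))

kappa2, eta_max = 100.0, 1.0  # hypothetical values
peak = peak_batch(kappa2, beta_noise=0.5)
print("peak batch size:", peak)
for batch in (peak / 16, peak / 4, peak, 4 * peak, 16 * peak):
    print(f"B={batch:10.2f}  eta*={adam_lr_surge(batch, kappa2, 0.5, eta_max):.4f}")

# With beta_noise >= 1 there is no peak and eta* keeps increasing with B.
print([round(adam_lr_surge(b, kappa2, 2.0, eta_max), 4) for b in (10, 100, 1000)])
```

The first table rises to `eta_max` at the peak and then falls, while the last line stays monotonically increasing.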

How can we intuitively understand the Surge phenomenon? The author believes that this is essentially a manifestation of the suboptimality of adaptive learning rate strategies. Taking the approximation $\tilde{\boldsymbol{u}}_B = \text{sign}(\tilde{\boldsymbol{g}}_B)$ as an example, the larger $B$ is, the more accurate $\tilde{\boldsymbol{g}}_B$ becomes; as $B\to \infty$, the update becomes $\text{sign}(\boldsymbol{g})$. However, is $\text{sign}(\boldsymbol{g})$ the best update direction? Not necessarily. Especially in the later stages of training, this adaptive strategy might even have negative effects. Therefore, when $B$ is appropriately chosen, the noise in $\text{sign}(\tilde{\boldsymbol{g}}_B)$ might actually correct this suboptimality, whereas when $B$ continues to increase, the noise decreases and so does the opportunity for correction, which in turn calls for a more cautious, i.e., smaller, learning rate.

Efficiency Relationship
#

Similar to the SGD analysis, finally we can consider $\Delta\mathcal{L}$. Substituting equation (8) into equation (5), restoring the notation $B$ and simplifying (the simplification process requires no approximation) yields:

$$ \Delta \mathcal{L} \approx \frac{\Delta \mathcal{L}_{\max}}{1 + \mathcal{B}_{\text{noise-2}}/B} \quad \text{(9)} $$

where

$$ \Delta \mathcal{L}_{\max} = \frac{\beta_{\text{noise}}\eta_{\max}\sum_i|g_i|}{1 + \beta_{\text{noise}}^2},\quad \mathcal{B}_{\text{noise-2}} = \frac{\pi\kappa^2\beta_{\text{noise}}^2}{2(1 + \beta_{\text{noise}}^2)} \quad \text{(10)} $$

Note that $\mathcal{B}_{\text{noise-2}}$ is a new notation; it is not $\mathcal{B}_{\text{noise}}$. The latter is the theoretically optimal Batch Size derived by solving $\beta=\beta_{\text{noise}}$, and the result is:

$$ \mathcal{B}_{\text{noise}} = \frac{\pi\kappa^2\beta_{\text{noise}}^2}{2(1 - \beta_{\text{noise}}^2)} $$

Their relationship is:

$$ \frac{1}{\mathcal{B}_{\text{noise-2}}} - \frac{1}{\mathcal{B}_{\text{noise}}} = \frac{4}{\pi\kappa^2}\quad\Rightarrow\quad \mathcal{B}_{\text{noise}} = \left(\frac{1}{\mathcal{B}_{\text{noise-2}}} - \frac{4}{\pi\kappa^2}\right)^{-1} \quad \text{(11)} $$

Since equation (9) is formally identical to equation (2) for SGD, the analysis in that section also applies. Therefore, equation (3) can also be derived:

$$ \left(\frac{S}{S_{\min}} - 1\right)\left(\frac{E}{E_{\min}} - 1\right) = 1 $$

Except now $E_{\min}/S_{\min} = \mathcal{B}_{\text{noise-2}}$. In this way, we have a scheme for estimating $\beta_{\text{noise}}$ and $\mathcal{B}_{\text{noise}}$: by performing multiple experiments to obtain several $(S,E)$ pairs, and simultaneously estimating $\kappa^2$ during the experiment. Then, fitting the equation yields $E_{\min},S_{\min}$, which then allows for the estimation of $\mathcal{B}_{\text{noise-2}}$. Finally, $\beta_{\text{noise}}$ can be solved from equation (10).

If $\beta_{\text{noise}} \geq 1$, then no optimal $\mathcal{B}_{\text{noise}}$ exists. If $\beta_{\text{noise}} \gg 1$, it indicates that the diagonal elements of the Hessian matrix dominate, and in this case, scaling law (7) applies, meaning increasing Batch Size can always appropriately increase the learning rate. When $\beta_{\text{noise}} < 1$, the optimal $\mathcal{B}_{\text{noise}}$ can be solved from equation (11), and if the Batch Size exceeds this value, the learning rate should actually decrease.
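The last two steps of this scheme amount to inverting equation (10) for $\beta_{\text{noise}}^2$ and then applying equation (11). The sketch below assumes $\mathcal{B}_{\text{noise-2}}$ and $\kappa^2$ have already been estimated; the numbers used are hypothetical.

```python
import math

def beta_noise_sq(b_noise_2: float, kappa2: float) -> float:
    """Invert equation (10): beta_noise^2 = B_noise-2 / (pi * kappa^2 / 2 - B_noise-2)."""
    half = math.pi * kappa2 / 2.0
    if not 0.0 < b_noise_2 < half:
        raise ValueError("equation (10) requires 0 < B_noise-2 < pi * kappa^2 / 2")
    return b_noise_2 / (half - b_noise_2)

def surge_batch(b_noise_2: float, kappa2: float):
    """Equation (11): the Batch Size beyond which eta* should decrease.

    Returns None when beta_noise >= 1, i.e. when no such Batch Size exists.
    """
    if beta_noise_sq(b_noise_2, kappa2) >= 1.0:
        return None
    return 1.0 / (1.0 / b_noise_2 - 4.0 / (math.pi * kappa2))

kappa2 = 100.0                # hypothetical estimate of kappa^2
b_noise_2 = 10.0 * math.pi    # hypothetical fitted E_min / S_min
print("beta_noise^2:", beta_noise_sq(b_noise_2, kappa2))
print("B_noise:", surge_batch(b_noise_2, kappa2))
print("no surge case:", surge_batch(30.0 * math.pi, kappa2))
```

The check `beta_noise_sq(...) >= 1` is equivalent to the right-hand side of equation (11) having a non-positive denominator.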

Supplementary Notes
#

It should be pointed out that the starting points and final conclusions of the analyses in the preceding sections are largely similar to those in the original paper 《Surge Phenomenon in Optimal Learning Rate and Batch Size Scaling》, but the approximate treatments in the intermediate steps differ.

Most of the conclusions in the original paper are approximate results under the assumption $B \ll \pi(\sigma_i/g_i)^2/2$, which led to the conclusion that the Surge phenomenon almost always appears. This is actually not very rigorous. The most obvious problem is the form of the assumption $B \ll \pi(\sigma_i/g_i)^2/2$ itself: its right-hand side depends on $i$. We cannot assign a separate Batch Size to each component, so to get a global result, the condition can only be $B \ll \min_i \pi(\sigma_i/g_i)^2/2$, which is perhaps too restrictive.

This article’s approach introduces approximation (6), which can be seen as a mean-field approximation, intuitively more reasonable than the pointwise assumption $B \ll \pi(\sigma_i/g_i)^2/2$. Therefore, in principle, the conclusions should be more precise, for example, leading to the conclusion that “even if the off-diagonal elements of the Hessian matrix are not negligible, the Surge phenomenon may not necessarily occur” (depending on $\beta_{\text{noise}}$). In particular, this precision does not sacrifice simplicity; for instance, equation (8) is also very concise and clear, and equation (9) has the same form as the original paper, without requiring additional approximate assumptions, and so on.

Finally, a brief reflection: OpenAI’s analysis of SGD dates back to 2018, while the paper on the Surge phenomenon was only published in the middle of this year (2024). It took 6 years to go from SGD to Adam, which is quite surprising. This is probably largely due to OpenAI’s “prestige” and the guess based on equation (1), which made everyone think there wasn’t much left to do for Adam; unexpectedly, Adam turns out to have some new characteristics. Of course, the question of how reasonable and representative the approximation $\tilde{\boldsymbol{u}}_B = \text{sign}(\tilde{\boldsymbol{g}}_B)$ is for Adam in practical situations still deserves further thought, in the author’s opinion.

Summary
#

This article discusses the classic “Batch Size and Learning Rate Scaling Law” problem from multiple perspectives, with a particular focus on OpenAI’s derivation and conclusions based on the second-order approximation of the loss function, as well as subsequent work that uses the same idea to analyze the Adam optimizer.

@online{kexuefm-10542,
        title={How Should the Learning Rate Change When Batch Size Increases?},
        author={苏剑林},
        year={2024},
        month={11},
        url={\url{https://kexue.fm/archives/10542}},
}