This is a gemini-2.5-flash translation of a Chinese article.
It has NOT been vetted for errors. You should have the original article open in a parallel tab at all times.
As is well known, current Transformer models are getting bigger, but this “bigness” usually refers to “width” rather than “depth”. For example, although GPT-3 has hundreds of billions of parameters, it’s only a 96-layer Transformer model, far from the depth we can imagine. What limits the “deep” development of Transformers? Some readers might think it’s computational power, but a “wide and shallow” model doesn’t necessarily require much less computational power than a “narrow and deep” one. So, computational power is not the main limitation; ultimately, it’s due to the inherent training difficulties of Transformers. The general view is that the training difficulty of deep models stems from gradient vanishing or gradient explosion. However, practice shows that even with various methods to improve gradients, deep models are still not easy to train.
Recent works (such as Admin) point out that the fundamental difficulty in training deep models lies in “increment explosion”: the deeper the model, the larger the perturbation that a single update causes to the output. Last week’s paper, “DeepNet: Scaling Transformers to 1,000 Layers”, follows this line of thought and carries out a scale analysis. Based on the results, it adjusts the model’s normalization and initialization schemes, and ultimately succeeds in training a 1000-layer Transformer. The whole analysis is quite instructive, so let’s work through it.
Increment Explosion
The original paper’s complete analysis is quite long, and some assumptions or descriptions are not entirely reasonable upon closer inspection. Therefore, in this article, I will try to correct these issues, attempting to derive similar results in a more reasonable way.
Assume the loss function is $\mathcal{L}(\boldsymbol{\theta})$, and $\boldsymbol{\theta}$ is its parameter. Consider the increment of the loss function when the parameter changes from $\boldsymbol{\theta}$ to $\boldsymbol{\theta}+\Delta\boldsymbol{\theta}$:
$$ \Delta\mathcal{L} = \mathcal{L}(\boldsymbol{\theta}+\Delta\boldsymbol{\theta}) - \mathcal{L}(\boldsymbol{\theta}) \approx \langle\nabla_{\boldsymbol{\theta}}\mathcal{L}(\boldsymbol{\theta}),\Delta\boldsymbol{\theta}\rangle $$
For SGD, we have $\Delta\boldsymbol{\theta}=-\eta \nabla_{\boldsymbol{\theta}}\mathcal{L}(\boldsymbol{\theta})$, so $\Delta\mathcal{L} \approx -\eta\Vert\nabla_{\boldsymbol{\theta}}\mathcal{L}(\boldsymbol{\theta})\Vert^2$. Suppose the model has $N$ layers, each with $K$ parameter matrices ($K$ is roughly a constant). With Xavier initialization and the usual normalization layers, the gradient of each parameter matrix can be kept at $\mathcal{O}(1)$ magnitude. Summing the $NK$ squared gradient norms then gives $\Delta\mathcal{L}=\mathcal{O}(\eta NK)$: the change in the loss at each step grows in proportion to the depth $N$. The deeper the model, the larger each update, so early in training the model is more likely to be pushed into a poor local minimum, leading to training stagnation or even collapse. This is the “increment explosion” problem.
There are two solutions at this point: first, use a smaller learning rate (not exceeding $\eta/N$ magnitude) for training in the initial stage, and then gradually increase the learning rate; this is the Warmup trick. Second, adjust the initialization scheme so that the gradients of the parameters are of $\mathcal{O}(1/\sqrt{N})$ magnitude, which automatically offsets the effect of model depth.
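The depth dependence can be illustrated with a toy computation. This is a minimal sketch under loudly stated assumptions: $K = 6$ matrices per layer and $64\times 64$ shapes are arbitrary choices, and each matrix simply receives a random gradient whose Frobenius norm is close to 1 (the $\mathcal{O}(1)$ assumption) rather than a gradient from a real model. It evaluates the first-order SGD estimate $\Delta\mathcal{L} \approx -\eta\Vert\nabla_{\boldsymbol{\theta}}\mathcal{L}\Vert^2$ with and without the $\mathcal{O}(1/\sqrt{N})$ gradient scaling of the second remedy; `sgd_loss_increment` is a hypothetical helper name.

```python
import numpy as np


def sgd_loss_increment(num_layers, mats_per_layer=6, dim=64, eta=1e-2,
                       grad_scale=1.0, seed=0):
    """First-order SGD loss change: dL ~ -eta * sum of squared gradient norms.

    Each of the N*K parameter matrices gets a random gradient whose entries have
    variance 1/dim**2, so its Frobenius norm is about 1 (the O(1) assumption),
    optionally multiplied by grad_scale.
    """
    rng = np.random.default_rng(seed)
    total = 0.0
    for _ in range(num_layers * mats_per_layer):
        g = grad_scale * rng.standard_normal((dim, dim)) / dim
        total += np.sum(g * g)  # squared Frobenius norm of this matrix's gradient
    return -eta * total


for n in (6, 48, 384):
    plain = sgd_loss_increment(n)
    scaled = sgd_loss_increment(n, grad_scale=1 / np.sqrt(n))
    print(f"N={n:4d}  dL={plain:9.4f}  dL with O(1/sqrt(N)) gradients={scaled:8.4f}")
```

Without scaling, $\Delta\mathcal{L}$ grows roughly in proportion to $N$ (about $-\eta K N$ here); with the gradients scaled by $1/\sqrt{N}$ it stays near $-\eta K$ regardless of depth.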
Magnitude Analysis
How can we achieve the second solution? We can try to analyze the gradients of the Transformer. However, exact gradient calculation is quite cumbersome, and in fact, we don’t need exact gradients; we only need to perform a magnitude analysis of the gradients. Therefore, we can use the following “magnitude decomposition” trick to transform it into a scalar derivative problem.
For a matrix $\boldsymbol{W}$, we decompose it into the form $\boldsymbol{W}=\lambda \boldsymbol{U}$, where
$$ \lambda = \mathop{\text{argmin}}_{\kappa > 0} \Vert \boldsymbol{W}\boldsymbol{W}^{\top}/\kappa^2 - \boldsymbol{I}\Vert,\quad \boldsymbol{U} = \boldsymbol{W}/\lambda $$
In plain terms, we decompose the matrix into a scalar $\lambda$ times a matrix $\boldsymbol{U}$ that is as close to orthogonal as possible. Because $\boldsymbol{U}$ is nearly orthogonal, it serves as a standard reference, and $\lambda$ captures the magnitude of $\boldsymbol{W}$. If $\boldsymbol{W}$ uses Xavier initialization, $\lambda$ plays the role of the gain parameter, i.e., $\boldsymbol{W}$ is the Xavier initialization multiplied by an extra factor $\lambda$. This works because a Xavier-initialized matrix is already close to orthogonal, as discussed in “A Geometric Perspective on Model Parameter Initialization Strategies”.
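The article does not specify which matrix norm the argmin uses. Assuming the Frobenius norm, the minimization has a closed form, and the sketch below (with the hypothetical helper name `magnitude_decompose`) uses it. It checks two properties the text relies on: for an exactly orthogonal matrix scaled by $c$ the decomposition returns $\lambda = c$, and multiplying a Xavier-style random matrix by a gain multiplies $\lambda$ by that gain while leaving $\boldsymbol{U}$ unchanged.

```python
import numpy as np


def magnitude_decompose(W):
    """Split W into lam * U with U as close to orthogonal as possible.

    Assumes the norm in the argmin is the Frobenius norm. Writing A = W W^T and
    s = 1/kappa^2, ||s A - I||_F^2 = s^2 ||A||_F^2 - 2 s tr(A) + const is a
    quadratic in s, minimized at s = tr(A) / ||A||_F^2. Since tr(A) = ||W||_F^2,
    this gives lam = ||W W^T||_F / ||W||_F.
    """
    lam = np.linalg.norm(W @ W.T) / np.linalg.norm(W)  # 2-D norm defaults to Frobenius
    return lam, W / lam


rng = np.random.default_rng(0)
d = 256
Q, _ = np.linalg.qr(rng.standard_normal((d, d)))  # an exactly orthogonal matrix

lam, U = magnitude_decompose(3.0 * Q)
print(lam)  # the gain 3.0, up to floating-point rounding

# Xavier-style Gaussian init with fan_in = fan_out = d: std = sqrt(2 / (2d)) = 1/sqrt(d)
W = rng.standard_normal((d, d)) / np.sqrt(d)
lam1, _ = magnitude_decompose(W)
lam2, _ = magnitude_decompose(2.5 * W)
print(lam2 / lam1)  # the extra gain 2.5 shows up entirely in lam
```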
Under this decomposition, we have
$$ \frac{\partial \mathcal{L}(\lambda \boldsymbol{U})}{\partial \lambda} = \left\langle\frac{\partial \mathcal{L}(\lambda \boldsymbol{U})}{\partial (\lambda \boldsymbol{U})}, \boldsymbol{U}\right\rangle = \left\langle\frac{\partial \mathcal{L}(\boldsymbol{W})}{\partial \boldsymbol{W}}, \boldsymbol{U}\right\rangle $$
This shows that $\frac{\partial \mathcal{L}}{\partial \lambda}$ is the inner product of $\frac{\partial \mathcal{L}}{\partial \boldsymbol{W}}$ with the nearly orthogonal matrix $\boldsymbol{U}$, so the two are of the same order of magnitude. A magnitude analysis of $\frac{\partial \mathcal{L}}{\partial \lambda}$ is therefore equivalent to a magnitude analysis of $\frac{\partial \mathcal{L}}{\partial \boldsymbol{W}}$: $\frac{\partial \mathcal{L}}{\partial \lambda}$ serves as a simple “probe” for the magnitude of $\frac{\partial \mathcal{L}}{\partial \boldsymbol{W}}$, and the matrix derivative is reduced to a scalar derivative, which simplifies the analysis.
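The identity above is just the chain rule, and it can be checked numerically. The sketch below assumes a toy least-squares loss $\mathcal{L}(\boldsymbol{W}) = \tfrac{1}{2}\Vert\boldsymbol{X}\boldsymbol{W}-\boldsymbol{Y}\Vert_F^2$ (not a Transformer loss; `X`, `Y`, `loss` and `grad_W` are illustrative names) and compares a finite-difference derivative of $\lambda \mapsto \mathcal{L}(\lambda\boldsymbol{U})$ with $\langle\partial\mathcal{L}/\partial\boldsymbol{W}, \boldsymbol{U}\rangle$.

```python
import numpy as np

rng = np.random.default_rng(1)
n, d_in, d_out = 32, 16, 8
X = rng.standard_normal((n, d_in))
Y = rng.standard_normal((n, d_out))


def loss(W):
    # Toy least-squares loss, used only to check the chain rule
    return 0.5 * np.sum((X @ W - Y) ** 2)


def grad_W(W):
    # Analytic gradient of the toy loss with respect to W
    return X.T @ (X @ W - Y)


U = rng.standard_normal((d_in, d_out)) / np.sqrt(d_in)
lam = 0.7

# Right-hand side: <dL/dW, U> evaluated at W = lam * U
rhs = np.sum(grad_W(lam * U) * U)

# Left-hand side: central finite difference of lam -> L(lam * U)
# (exact up to rounding here, because the toy loss is quadratic in lam)
h = 1e-3
lhs = (loss((lam + h) * U) - loss((lam - h) * U)) / (2 * h)

print(lhs, rhs)
```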
Feedforward Gradients
Many experimental results show that although Pre-Norm is easier to train than Post-Norm, Post-Norm often yields better final results. Therefore, the original paper retained the Post-Norm structure and considered a more general form (DeepNorm):
$$ \boldsymbol{x}_{l+1} = \text{LN}(\alpha\boldsymbol{x}_l + F(\boldsymbol{x}_l)) = \text{LN}(\boldsymbol{x}_l + F(\boldsymbol{x}_l)/\alpha) $$
where $\alpha > 0$ is a constant; the two forms are equal because LN is invariant to rescaling its input by a positive constant. For simplicity, let’s first consider the FFN layer, in which case
$$ \boldsymbol{x}_{l+1} = \text{LN}(\underbrace{\boldsymbol{x}_l + \phi(\boldsymbol{x}_l \boldsymbol{W}_1)\boldsymbol{W}_2/\alpha}_{\text{denoted as }\boldsymbol{z}_{l+1}}) $$
Here, $\phi$ is an activation function, typically ReLU or one of its variants (Swish, GeLU, etc.), which (approximately) satisfy $\phi(\lambda x) = \lambda \phi(x),\forall \lambda > 0$. Using the magnitude-decomposition probe from the previous section, we get
$$ \boldsymbol{x}_{l+1} = \text{LN}(\underbrace{\boldsymbol{x}_l + \lambda_1 \lambda_2 \phi(\boldsymbol{x}_l \boldsymbol{U}_1)\boldsymbol{U}_2/\alpha}_{\text{denoted as }\boldsymbol{z}_{l+1}}) $$
Taking the gradients with respect to the $\lambda$’s:
$$ \begin{aligned} \frac{\partial \mathcal{L}}{\partial \lambda_1} = \frac{\partial \mathcal{L}}{\partial \boldsymbol{x}_{l+1}}\frac{\partial \boldsymbol{x}_{l+1}}{\partial \boldsymbol{z}_{l+1}}\frac{\partial \boldsymbol{z}_{l+1}}{\partial \lambda_1} = \frac{\partial \mathcal{L}}{\partial \boldsymbol{x}_{l+1}}\frac{\partial \boldsymbol{x}_{l+1}}{\partial \boldsymbol{z}_{l+1}}\frac{\lambda_2 \phi(\boldsymbol{x}_l \boldsymbol{U}_1)\boldsymbol{U}_2}{\alpha} \\ \frac{\partial \mathcal{L}}{\partial \lambda_2} = \frac{\partial \mathcal{L}}{\partial \boldsymbol{x}_{l+1}}\frac{\partial \boldsymbol{x}_{l+1}}{\partial \boldsymbol{z}_{l+1}}\frac{\partial \boldsymbol{z}_{l+1}}{\partial \lambda_2} = \frac{\partial \mathcal{L}}{\partial \boldsymbol{x}_{l+1}}\frac{\partial \boldsymbol{x}_{l+1}}{\partial \boldsymbol{z}_{l+1}}\frac{\lambda_1 \phi(\boldsymbol{x}_l \boldsymbol{U}_1)\boldsymbol{U}_2}{\alpha} \end{aligned} $$
We assert that $\frac{\partial \mathcal{L}}{\partial \boldsymbol{x}_{l+1}}$ and $\frac{\partial \boldsymbol{x}_{l+1}}{\partial \boldsymbol{z}_{l+1}}$ are both $\mathcal{O}(1)$. Furthermore, since $\boldsymbol{U}_1$ and $\boldsymbol{U}_2$ are both close to orthogonal matrices, $\phi(\boldsymbol{x}_l \boldsymbol{U}_1)\boldsymbol{U}_2$ is also $\mathcal{O}(1)$. Therefore, we finally have
$$ \frac{\partial \mathcal{L}}{\partial \lambda_1} = \mathcal{O}\left(\frac{\lambda_2}{\alpha}\right),\quad \frac{\partial \mathcal{L}}{\partial \lambda_2} = \mathcal{O}\left(\frac{\lambda_1}{\alpha}\right) $$
Self-Attention
Now consider Self-Attention. For magnitude analysis, we can consider a single-head attention, which takes the form
$$ \boldsymbol{x}_{l+1} = \text{LN}(\boldsymbol{x}_l + \sigma(\boldsymbol{x}_l \boldsymbol{W}_q\boldsymbol{W}_k^{\top}\boldsymbol{x}_l^{\top})\boldsymbol{x}_l\boldsymbol{W}_v\boldsymbol{W}_o/\alpha) $$
where $\sigma(\cdot)$ denotes the softmax operation. The attention scaling factor is omitted here. After magnitude decomposition, the above equation becomes
$$ \boldsymbol{x}_{l+1} = \text{LN}(\underbrace{\boldsymbol{x}_l + \lambda_v\lambda_o \sigma (\lambda_q\lambda_k\boldsymbol{x}_l \boldsymbol{U}_q\boldsymbol{U}_k^{\top}\boldsymbol{x}_l^{\top})\boldsymbol{x}_l\boldsymbol{U}_v\boldsymbol{U}_o/\alpha}_{\text{denoted as }\boldsymbol{z}_{l+1}}) $$
Now we can calculate the gradient for each $\lambda$. Because of the softmax, the gradients with respect to $\lambda_q, \lambda_k$ are themselves very small and do not significantly affect the final update, so it suffices to consider $\lambda_v, \lambda_o$:
$$ \begin{aligned} \frac{\partial \mathcal{L}}{\partial \lambda_v} = \frac{\partial \mathcal{L}}{\partial \boldsymbol{x}_{l+1}}\frac{\partial \boldsymbol{x}_{l+1}}{\partial \boldsymbol{z}_{l+1}}\frac{\partial \boldsymbol{z}_{l+1}}{\partial \lambda_v} = \frac{\partial \mathcal{L}}{\partial \boldsymbol{x}_{l+1}}\frac{\partial \boldsymbol{x}_{l+1}}{\partial \boldsymbol{z}_{l+1}}\frac{\lambda_o \sigma (\lambda_q\lambda_k\boldsymbol{x}_l \boldsymbol{U}_q\boldsymbol{U}_k^{\top}\boldsymbol{x}_l^{\top})\boldsymbol{x}_l\boldsymbol{U}_v\boldsymbol{U}_o}{\alpha} \\ \frac{\partial \mathcal{L}}{\partial \lambda_o} = \frac{\partial \mathcal{L}}{\partial \boldsymbol{x}_{l+1}}\frac{\partial \boldsymbol{x}_{l+1}}{\partial \boldsymbol{z}_{l+1}}\frac{\partial \boldsymbol{z}_{l+1}}{\partial \lambda_o} = \frac{\partial \mathcal{L}}{\partial \boldsymbol{x}_{l+1}}\frac{\partial \boldsymbol{x}_{l+1}}{\partial \boldsymbol{z}_{l+1}}\frac{\lambda_v \sigma (\lambda_q\lambda_k\boldsymbol{x}_l \boldsymbol{U}_q\boldsymbol{U}_k^{\top}\boldsymbol{x}_l^{\top})\boldsymbol{x}_l\boldsymbol{U}_v\boldsymbol{U}_o}{\alpha} \end{aligned} $$
Similarly, we assert that $\frac{\partial \mathcal{L}}{\partial \boldsymbol{x}_{l+1}}$ and $\frac{\partial \boldsymbol{x}_{l+1}}{\partial \boldsymbol{z}_{l+1}}$ are both $\mathcal{O}(1)$. Note also that the softmax outputs a probability distribution, which is then used to take a weighted average of the tokens of $\boldsymbol{x}_l$. An average is generally of the same order of magnitude as the vectors being averaged, so we also treat $\sigma (\lambda_q\lambda_k\boldsymbol{x}_l \boldsymbol{U}_q\boldsymbol{U}_k^{\top}\boldsymbol{x}_l^{\top})\boldsymbol{x}_l\boldsymbol{U}_v\boldsymbol{U}_o$ as $\mathcal{O}(1)$. The result is therefore similar to that of the FFN layer:
$$ \frac{\partial \mathcal{L}}{\partial \lambda_v} = \mathcal{O}\left(\frac{\lambda_o}{\alpha}\right),\quad \frac{\partial \mathcal{L}}{\partial \lambda_o} = \mathcal{O}\left(\frac{\lambda_v}{\alpha}\right) $$
Preliminary Conclusion
Now, for both FFN and Self-Attention, we have reached similar conclusions. For simplicity, let’s assume that the magnitude of each parameter (at least during initialization) is consistent, i.e., all $\lambda$ values are the same. Then the overall conclusion is
$$ \frac{\partial \mathcal{L}}{\partial \lambda} = \mathcal{O}\left(\frac{\lambda}{\alpha}\right) $$
That is, the gradient has magnitude $\mathcal{O}(\lambda/\alpha)$. On the other hand, a Transformer with $N$ layers generally means $N$ Self-Attention layers plus $N$ FFN layers, so strictly speaking there are $2N$ layers. Following the analysis in the “Increment Explosion” section, we therefore need the gradients to be $\mathcal{O}(1/\sqrt{2N})$, which the equation above tells us can be achieved by setting $\lambda/\alpha=1/\sqrt{2N}$. The original paper uses a looser bound and obtains $\lambda/\alpha = 1/\sqrt{4N}$, which is the same in terms of magnitude.
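Both scaling laws can be checked numerically on a single residual branch. This sketch makes several assumptions that are not in the original: ReLU as $\phi$ (so the homogeneity step is exact), random semi-orthogonal $\boldsymbol{U}$ matrices built with QR, small illustrative sizes, and one attention head without the scaling factor, as in the text. It differentiates the pre-LayerNorm sum $\boldsymbol{z}_{l+1}$ with respect to $\lambda_1$ (FFN) and $\lambda_v$ (attention) by finite differences, so the factors $\frac{\partial \mathcal{L}}{\partial \boldsymbol{x}_{l+1}}$ and $\frac{\partial \boldsymbol{x}_{l+1}}{\partial \boldsymbol{z}_{l+1}}$, which the text asserts are $\mathcal{O}(1)$, are left out. All function names are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(2)
T, d, h = 8, 32, 128  # tokens, model width, FFN width (illustrative sizes)


def semi_orthogonal(rows, cols):
    """Random matrix with orthonormal rows or columns, standing in for a near-orthogonal U."""
    q, _ = np.linalg.qr(rng.standard_normal((max(rows, cols), min(rows, cols))))
    return q if rows >= cols else q.T


def softmax(a):
    e = np.exp(a - a.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)


x = rng.standard_normal((T, d))
U1, U2 = semi_orthogonal(d, h), semi_orthogonal(h, d)
Uq, Uk, Uv, Uo = (semi_orthogonal(d, d) for _ in range(4))


def z_ffn(lam1, lam2, alpha):
    # Pre-LayerNorm sum of a DeepNorm FFN sublayer with ReLU, W1 = lam1*U1, W2 = lam2*U2
    return x + np.maximum(x @ (lam1 * U1), 0.0) @ (lam2 * U2) / alpha


def z_attn(lam_v, lam_o, alpha, lam_q=1.0, lam_k=1.0):
    # Single-head self-attention sublayer; the attention scaling factor is omitted
    scores = (x @ (lam_q * Uq)) @ (x @ (lam_k * Uk)).T
    return x + softmax(scores) @ x @ (lam_v * Uv) @ (lam_o * Uo) / alpha


def dz_dlam1(lam2, alpha, lam1=1.0, eps=1e-4):
    # Central finite difference of z_ffn with respect to lam1
    return (z_ffn(lam1 + eps, lam2, alpha) - z_ffn(lam1 - eps, lam2, alpha)) / (2 * eps)


def dz_dlam_v(lam_o, alpha, lam_v=1.0, eps=1e-4):
    # Central finite difference of z_attn with respect to lam_v
    return (z_attn(lam_v + eps, lam_o, alpha) - z_attn(lam_v - eps, lam_o, alpha)) / (2 * eps)


base_ffn = np.linalg.norm(dz_dlam1(1.0, 1.0))
base_attn = np.linalg.norm(dz_dlam_v(1.0, 1.0))
for partner, alpha in [(1.0, 2.0), (2.0, 1.0), (0.5, 4.0)]:
    r_ffn = np.linalg.norm(dz_dlam1(partner, alpha)) / base_ffn
    r_attn = np.linalg.norm(dz_dlam_v(partner, alpha)) / base_attn
    print(f"partner/alpha={partner / alpha:.3f}  FFN ratio={r_ffn:.6f}  attention ratio={r_attn:.6f}")
```

Because $\boldsymbol{z}_{l+1}$ is linear in $\lambda_1$ and in $\lambda_v$, each printed ratio matches `partner / alpha`: the derivative is proportional to $\lambda_2/\alpha$ for the FFN and to $\lambda_o/\alpha$ for attention, with everything else held fixed.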
Now we have a proportional relationship between $\lambda$ and $\alpha$, but we cannot directly obtain their specific values. According to the paper, starting from a symmetry perspective, setting $\lambda=1/\alpha$, we can solve for
$$ \alpha = (2N)^{1/4},\quad \lambda = (2N)^{-1/4} $$
However, a symmetry argument alone is not very convincing. We need to understand what the different choices actually lead to. To this end, we can compare two alternative solutions:
Alternative Solution 1: $\alpha=1,\lambda=(2N)^{-1/2}$. In this case, the parameter initialization is scaled down to $(2N)^{-1/2}$ of its original value, and the gradient is also scaled down to $(2N)^{-1/2}$ of its original value. According to SGD’s $\Delta\boldsymbol{\theta}=-\eta \nabla_{\boldsymbol{\theta}}\mathcal{L}(\boldsymbol{\theta})$, the update amount at each step is also $(2N)^{-1/2}$ of its original value. This means that the relative learning rate before and after adjustment does not change. Thus, it’s possible that initially $\lambda$ is at the $\mathcal{O}((2N)^{-1/2})$ level, but after a few steps on the training set, it might deviate from this magnitude.
Alternative Solution 2: $\alpha=(2N)^{1/2},\lambda=1$. In this case, the parameter initialization is not scaled down, but the gradient is scaled down to $(2N)^{-1/2}$ of its original value. According to SGD’s $\Delta\boldsymbol{\theta}=-\eta \nabla_{\boldsymbol{\theta}}\mathcal{L}(\boldsymbol{\theta})$, the update amount at each step is also $(2N)^{-1/2}$ of its original value. The relative learning rate before and after adjustment is significantly reduced, which could lead to very slow learning.
Both of these choices have drawbacks, so the intermediate solution $\alpha = (2N)^{1/4}$, $\lambda = (2N)^{-1/4}$ seems reasonable. It still scales the gradient to $(2N)^{-1/2}$ of its original value, while making the initial learning pace somewhat slower, but not too slow, which implicitly acts as a Warmup.
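The comparison can be made concrete with a little arithmetic. In this sketch, “relative step” means the SGD step size divided by the parameter scale, in units of $\eta$, i.e. $(\lambda/\alpha)/\lambda = 1/\alpha$; the depth $N = 100$ and the helper names `sgd_candidates` and `step_profile` are illustrative assumptions.

```python
def sgd_candidates(N):
    """(alpha, lam) pairs discussed above; all satisfy lam / alpha = (2N)^(-1/2)."""
    two_n = 2 * N
    return {
        "alt 1: alpha=1": (1.0, two_n ** -0.5),
        "alt 2: lam=1": (two_n ** 0.5, 1.0),
        "DeepNet (lam=1/alpha)": (two_n ** 0.25, two_n ** -0.25),
    }


def step_profile(alpha, lam):
    """Return (gradient scale, SGD step relative to parameter size, in units of eta)."""
    grad_scale = lam / alpha
    return grad_scale, grad_scale / lam


N = 100  # hypothetical depth: 100 attention + 100 FFN sublayers
for name, (alpha, lam) in sgd_candidates(N).items():
    grad_scale, relative_step = step_profile(alpha, lam)
    print(f"{name:22s} alpha={alpha:7.3f} lam={lam:6.4f} "
          f"grad scale={grad_scale:6.4f} relative step={relative_step:6.4f}")
```

All three rows share the same gradient scale $(2N)^{-1/2}$; they differ only in the relative step, which is unchanged for Alternative 1, reduced by $(2N)^{1/2}$ for Alternative 2, and reduced by the intermediate factor $(2N)^{1/4}$ for the symmetric choice.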
Multiple Optimizers
The derivations above were all based on SGD, but in practice NLP models are rarely trained with plain SGD. Adaptive learning rate optimizers are used more often, and they fall into two main categories: those that use the second moment to correct the learning rate, such as Adam and AdamW, and those that additionally scale the learning rate by the parameter’s magnitude, such as LAMB and AdaFactor. The original paper states: “We derive on SGD and then verify that it also works reasonably well on Adam,” but in theory the SGD conclusions do not carry over directly to these optimizers. In this section, we analyze them separately.
For Adam-type optimizers, the update at each step is approximately $\Delta\boldsymbol{\theta}=-\eta\,\text{sign}(\nabla_{\boldsymbol{\theta}}\mathcal{L}(\boldsymbol{\theta}))$, so $\Delta\mathcal{L} \approx -\eta\Vert\nabla_{\boldsymbol{\theta}}\mathcal{L}(\boldsymbol{\theta})\Vert_1$. This is proportional to the first power of the gradient rather than its square. Therefore, to make the update independent of the number of layers, the gradient should be scaled down to $1/(2N)$ of its original value, i.e., we should have $\lambda/\alpha=1/(2N)$. Again setting $\lambda=1/\alpha$, we get
$$ \alpha = (2N)^{1/2},\quad \lambda = (2N)^{-1/2} $$
For LAMB-type optimizers, the update at each step is approximately $\Delta\boldsymbol{\theta}=-\eta\Vert\boldsymbol{\theta}\Vert\,\text{sign}(\nabla_{\boldsymbol{\theta}}\mathcal{L}(\boldsymbol{\theta}))$, so $\Delta\mathcal{L} \approx -\eta\Vert\boldsymbol{\theta}\Vert\Vert\nabla_{\boldsymbol{\theta}}\mathcal{L}(\boldsymbol{\theta})\Vert_1$. Since the scaling factor is $\lambda$ for the parameters and $\lambda/\alpha$ for the gradients, $\Delta\mathcal{L}=\mathcal{O}(2N\lambda^2/\alpha)$, which requires $\lambda^2/\alpha=1/(2N)$. For this type of optimizer, the relative update at each step is the same (equal to the learning rate $\eta$) and does not change however $\alpha, \lambda$ are adjusted. We can therefore simply choose $\alpha=1,\lambda=(2N)^{-1/2}$.
The summary of results is as follows:
Optimizer | $\Delta\boldsymbol{\theta}$ | $\Delta\mathcal{L}$ | $\alpha$ | $\lambda$ |
---|---|---|---|---|
SGD | $-\eta \nabla_{\boldsymbol{\theta}}\mathcal{L}(\boldsymbol{\theta})$ | $-\eta\Vert\nabla_{\boldsymbol{\theta}}\mathcal{L}(\boldsymbol{\theta})\Vert^2$ | $(2N)^{1/4}$ | $(2N)^{-1/4}$ |
Adam | $-\eta\,\text{sign}(\nabla_{\boldsymbol{\theta}}\mathcal{L}(\boldsymbol{\theta}))$ | $-\eta\Vert\nabla_{\boldsymbol{\theta}}\mathcal{L}(\boldsymbol{\theta})\Vert_1$ | $(2N)^{1/2}$ | $(2N)^{-1/2}$ |
LAMB | $-\eta\Vert\boldsymbol{\theta}\Vert\,\text{sign}(\nabla_{\boldsymbol{\theta}}\mathcal{L}(\boldsymbol{\theta}))$ | $-\eta\Vert\boldsymbol{\theta}\Vert\Vert\nabla_{\boldsymbol{\theta}}\mathcal{L}(\boldsymbol{\theta})\Vert_1$ | $1$ | $(2N)^{-1/2}$ |
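The table can be encoded as a small lookup, which also makes the constraint behind each row explicit. This is a sketch for a Post-Norm model with $N$ attention plus $N$ FFN sublayers; the function name `deepnorm_constants` and the example depth $N = 500$ are hypothetical choices for illustration.

```python
def deepnorm_constants(N, optimizer):
    """(alpha, lam) for a Post-Norm model with N attention + N FFN sublayers,
    following the summary table; `optimizer` is 'sgd', 'adam' or 'lamb'."""
    two_n = 2 * N
    if optimizer == "sgd":   # dL ~ ||g||^2            -> lam / alpha = (2N)^(-1/2), lam = 1/alpha
        return two_n ** 0.25, two_n ** -0.25
    if optimizer == "adam":  # dL ~ ||g||_1            -> lam / alpha = (2N)^(-1),   lam = 1/alpha
        return two_n ** 0.5, two_n ** -0.5
    if optimizer == "lamb":  # dL ~ ||theta|| ||g||_1  -> lam^2 / alpha = (2N)^(-1), alpha = 1
        return 1.0, two_n ** -0.5
    raise ValueError(f"unknown optimizer: {optimizer!r}")


for opt in ("sgd", "adam", "lamb"):
    alpha, lam = deepnorm_constants(N=500, optimizer=opt)
    print(f"{opt:5s} alpha={alpha:8.4f} lam={lam:.5f} "
          f"residual branch weight lam^2/alpha={lam ** 2 / alpha:.2e}")
```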
Post-hoc Analysis
The previous two sections’ derivations both relied on the assertion that “$\frac{\partial \mathcal{L}}{\partial \boldsymbol{x}_{l+1}}$ and $\frac{\partial \boldsymbol{x}_{l+1}}{\partial \boldsymbol{z}_{l+1}}$ are both $\mathcal{O}(1)$”. Is this assertion valid? Let’s analyze it post-hoc.
It’s actually quite simple. After the adjustments mentioned, whether for the FFN layer or the Self-Attention layer, the weight of each residual branch is scaled down to $\lambda^2/\alpha$ of its original value in the initial stage. Regardless of the optimizer used, $\lambda^2/\alpha$ is a relatively small number, meaning that the entire model is close to an identity function in the initial stage. Therefore, $\frac{\partial \mathcal{L}}{\partial \boldsymbol{x}_{l+1}}$ and $\frac{\partial \boldsymbol{x}_{l+1}}{\partial \boldsymbol{z}_{l+1}}$ are naturally $\mathcal{O}(1)$, so the conclusion and the assertion are self-consistent.
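The near-identity claim can be illustrated on one FFN sublayer. This sketch assumes ReLU, semi-orthogonal $\boldsymbol{U}$ matrices from QR, a LayerNorm without learnable gain and bias, an input that is already normalized, and the SGD constants from the table; sizes and helper names are illustrative. It measures how far $\text{LN}(\boldsymbol{x} + \lambda^2\phi(\boldsymbol{x}\boldsymbol{U}_1)\boldsymbol{U}_2/\alpha)$ moves away from $\boldsymbol{x}$ as the depth grows.

```python
import numpy as np

rng = np.random.default_rng(3)
T, d, h = 8, 64, 256  # tokens, model width, FFN width (illustrative sizes)


def layer_norm(a, eps=1e-5):
    # LayerNorm over the feature axis, without the learnable gain and bias
    mu = a.mean(axis=-1, keepdims=True)
    var = a.var(axis=-1, keepdims=True)
    return (a - mu) / np.sqrt(var + eps)


def semi_orthogonal(rows, cols):
    # Random matrix with orthonormal rows or columns, standing in for a near-orthogonal U
    q, _ = np.linalg.qr(rng.standard_normal((max(rows, cols), min(rows, cols))))
    return q if rows >= cols else q.T


x = layer_norm(rng.standard_normal((T, d)))  # input that is already normalized
U1, U2 = semi_orthogonal(d, h), semi_orthogonal(h, d)
branch = np.maximum(x @ U1, 0.0) @ U2  # FFN branch with unit-gain weights


def relative_change(N):
    # SGD constants: the residual branch is weighted by lam^2 / alpha = (2N)^(-3/4)
    alpha, lam = (2 * N) ** 0.25, (2 * N) ** -0.25
    out = layer_norm(x + lam * lam * branch / alpha)
    return np.linalg.norm(out - x) / np.linalg.norm(x)


for N in (6, 100, 1000):
    print(N, relative_change(N))
```

The relative change shrinks as $N$ grows and stays below the branch weight $\lambda^2/\alpha$, so each sublayer starts out close to the identity map.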
Additionally, some readers might wonder whether the same analysis applies to Pre-Norm structures. The answer is yes, and the conclusions are largely consistent. However, since Pre-Norm applies the normalization to the input of each residual branch rather than after the residual addition, there is no need for the $\alpha$ parameter. The conclusion is therefore that all $\alpha$ values in the Post-Norm results become 1, and the corresponding $\lambda$ values are recalculated.
Finally, readers might question whether model depth is really that important, given how much effort has gone into making models deeper. It is. The original paper reports a striking experimental result: a 200-layer “deep and narrow” model (3.2 billion parameters) outperformed the previous 48-layer “shallow and wide” SOTA model (12 billion parameters).
Summary
This article analyzed the bottleneck in making Transformers “deep” and provided corresponding solutions. The main ideas in this article stem from Microsoft’s new DeepNet, and the analysis process of the original paper has been simplified and improved upon.
@online{kexuefm-8978,
title={What's so Difficult About Training a 1000-Layer Transformer?},
author={苏剑林},
year={2022},
month={03},
url={\url{https://kexue.fm/archives/8978}},
}