This is a gemini-2.5-flash translation of a Chinese article.
It has NOT been vetted for errors. You should have the original article open in a parallel tab at all times.
By Su Jianlin | 2021-08-17 | 242693 Readers
A few days ago, when training a new Transformer model, I found that it wouldn’t converge no matter how I trained it. After some debugging, I discovered that I had forgotten to divide by $\boldsymbol{\sqrt{d}}$ after $\boldsymbol{Q}\boldsymbol{K}^{\top}$ when performing Self Attention. This prompted me to revisit why dividing by $\sqrt{d}$ is so important. Of course, Google’s T5 does not divide by $\sqrt{d}$, yet it can still converge normally, because it made some adjustments to its initialization strategy. So, this matter is also related to initialization.
Taking this opportunity, this article will sort out topics such as model initialization, parameterization, and normalization, with the discussion primarily centered around Transformer models.
Sampling Distributions#
Initialization naturally involves random sampling, so here we first introduce common sampling distributions. Generally, we sample from random distributions with specified means and variances for initialization. There are three commonly used random distributions: Normal distribution, Uniform distribution, and Truncated Normal distribution.
Normal and Uniform distributions are, of course, very common. The Normal distribution is usually denoted as $\mathcal{N}(\mu,\sigma^2)$, where $\mu$ is the mean and $\sigma^2$ is the variance. The Uniform distribution on the interval $[a,b]$ is generally denoted as $U[a,b]$, with a mean of $\frac{a+b}{2}$ and a variance of $\frac{(b-a)^2}{12}$. Therefore, if a mean $\mu$ and variance $\sigma^2$ are specified, the corresponding Uniform distribution is $U[\mu-\sqrt{3}\sigma,\mu + \sqrt{3}\sigma]$.
Generally speaking, Normal distribution sampling results in greater diversity, but it is theoretically unbounded, and sampling results with excessively large absolute values might be detrimental to optimization. Conversely, the Uniform distribution is bounded, but its sampling results are usually less varied. This led to the creation of the “Truncated Normal distribution”, which combines the advantages of both. A Truncated Normal distribution specifies both a mean $\mu$ and a variance $\sigma^2$, as well as an interval $[a,b]$. It samples from $\mathcal{N}(\mu,\sigma^2)$, and if the sampled result falls within $[a,b]$, it is retained; otherwise, sampling is repeated until a result falls within $[a,b]$.
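The resampling procedure described above can be sketched in a few lines of Python. This is only a minimal illustration using the standard library; the helper name `truncated_normal`, the seed, and the sample count are choices made for this sketch and not part of any library API.

```python
import random
import statistics

def truncated_normal(mu, sigma, a, b, rng):
    """Draw one sample from N(mu, sigma^2) restricted to [a, b] by rejection."""
    while True:
        x = rng.gauss(mu, sigma)
        if a <= x <= b:          # keep the draw only if it falls inside the interval
            return x

rng = random.Random(0)  # fixed seed, chosen arbitrarily for reproducibility
samples = [truncated_normal(0.0, 1.0, -2.0, 2.0, rng) for _ in range(100_000)]

print(min(samples), max(samples))       # both lie inside [-2, 2]
print(statistics.pvariance(samples))    # noticeably below 1, since the tails are cut off
```

Rejection sampling is efficient here because about 95% of the draws from a normal distribution already fall within two standard deviations of the mean.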
In TensorFlow’s built-in tf.random.truncated_normal, $a=\mu-2\sigma$ and $b=\mu+2\sigma$ are hardcoded. By symmetry, the actual mean of the function’s sampling results is still $\mu$, but the actual variance is $\gamma\sigma^2$, where:
$$ \gamma=\frac{\int_{-2}^2 e^{-x^2/2}x^2 dx}{\int_{-2}^2 e^{-x^2/2} dx}=0.7737413\dots $$To obtain sampling results with a variance of $\sigma^2$, the standard deviation passed to the function should be $\frac{\sigma}{\sqrt{\gamma}}=1.1368472\dots\sigma$.
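As a sanity check on the constant above, the variance of a standard normal truncated to $[-c, c]$ has the closed form $1 - 2c\,\varphi(c)/(2\Phi(c)-1)$, where $\varphi$ and $\Phi$ are the standard normal density and CDF. The sketch below evaluates it for $c=2$; the function name is hypothetical.

```python
import math

def truncated_variance_ratio(c=2.0):
    """Variance of a standard normal restricted to [-c, c] (closed form)."""
    pdf_c = math.exp(-c * c / 2.0) / math.sqrt(2.0 * math.pi)  # density at c
    mass = math.erf(c / math.sqrt(2.0))                        # P(-c <= X <= c)
    return 1.0 - 2.0 * c * pdf_c / mass

gamma = truncated_variance_ratio(2.0)
stddev_scale = 1.0 / math.sqrt(gamma)   # factor to apply to the requested stddev
print(gamma, stddev_scale)
```

Multiplying the desired standard deviation by `stddev_scale` before passing it to a sampler that truncates at two standard deviations should give samples whose variance is close to the intended $\sigma^2$.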
Stable Second Moment#
In my previous article 《Understanding Model Parameter Initialization Strategies from a Geometric Perspective》, I analyzed existing initialization methods from a geometric perspective. The general idea is that certain random matrices approximate orthogonal matrices, thereby ensuring model stability during the initial stages. While the geometric perspective offers intuitive advantages, it is usually difficult to generalize. Therefore, we will now understand initialization methods from an algebraic perspective.
In general tutorials, the idea behind deriving initialization methods is to try to make inputs and outputs have the same mean and variance. It is usually assumed that the input is a random vector with a mean of 0 and a variance of 1, and then an attempt is made to make the output’s mean 0 and variance 1. However, I believe this is unnecessary, and for some non-negative activation functions, a mean of 0 is simply impossible. In fact, we only need a metric to gauge whether a certain indicator “vanishes” or “explodes.” A 0 mean and 1 variance are not strictly necessary. Here, we use the second (raw) moment as a substitute, which can be seen as a variant of the L2 norm. Similar to variance, it can be used to measure whether an indicator “vanishes” or “explodes,” but it is relatively more general and simpler.
Now, let’s consider a fully connected layer without an activation function (assuming $m$ input nodes and $n$ output nodes):
$$ y_j = b_j + \sum_i x_i w_{i,j} $$For simplicity, we usually initialize the bias term $b_j$ to all zeros and set the mean of $w_{i,j}$, $\mathbb{E}[w_{i,j}]$, to 0. This helps simplify the following results, but it’s not strictly necessary; it’s just a relatively clear choice. We calculate the second moment:
$$ \begin{aligned} \mathbb{E}[y_j^2] =&\, \mathbb{E}\left[\left(\sum_i x_i w_{i,j}\right)^2\right]=\mathbb{E}\left[\left(\sum_{i_1} x_{i_1} w_{i_1,j}\right)\left(\sum_{i_2} x_{i_2} w_{i_2,j}\right)\right]\\ =&\, \mathbb{E}\left[\sum_{i_1, i_2} (x_{i_1}x_{i_2}) (w_{i_1,j} w_{i_2,j})\right] = \sum_{i_1, i_2} \mathbb{E}[x_{i_1}x_{i_2}] \mathbb{E}[w_{i_1,j} w_{i_2,j}] \end{aligned} $$Note that $w_{i_1,j}$ and $w_{i_2,j}$ are independent and identically distributed, so when $i_1\neq i_2$, $\mathbb{E}[w_{i_1,j}w_{i_2,j}]=\mathbb{E}[w_{i_1,j}]\mathbb{E}[w_{i_2,j}]=0$. Therefore, we only need to consider the case where $i_1=i_2=i$. Assuming the second moment of the input is 1, then
$$ \mathbb{E}[y_j^2] = \sum_{i} \mathbb{E}[x_i^2] \mathbb{E}[w_{i,j}^2]= m\mathbb{E}[w_{i,j}^2] $$So, to make $\mathbb{E}[y_j^2]$ equal to 1, we need $\mathbb{E}[w_{i,j}^2]=1/m$. Combining this with the assumption of a mean of 0, we get the initialization strategy for $w_{i,j}$: “independently and repeatedly sample from a random distribution with a mean of 0 and a variance of $1/m$.” This is LeCun initialization. Note that in this process, we made no assumptions about the input’s mean, so it works even if the inputs are all non-negative.
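A quick Monte Carlo check of this result, as a sketch under stated assumptions: the inputs are taken to be $|z|$ with $z\sim\mathcal{N}(0,1)$, so they are non-negative but still have second moment 1, and the weights are drawn from $\mathcal{N}(0, 1/m)$. The function name, seed, and sizes are illustrative choices.

```python
import math
import random

def output_second_moment(m, trials, rng):
    """Monte Carlo estimate of E[y^2] for y = sum_i x_i * w_i with w_i ~ N(0, 1/m)."""
    w_std = math.sqrt(1.0 / m)
    total = 0.0
    for _ in range(trials):
        y = 0.0
        for _ in range(m):
            x = abs(rng.gauss(0.0, 1.0))   # non-negative input with E[x^2] = 1
            w = rng.gauss(0.0, w_std)      # zero-mean weight with variance 1/m
            y += x * w
        total += y * y
    return total / trials

rng = random.Random(0)
estimate = output_second_moment(m=64, trials=5_000, rng=rng)
print(estimate)  # expected to be close to 1 even though the inputs have nonzero mean
```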
Activation Functions#
Of course, this is only for scenarios without activation functions. If activation functions are considered, specific analysis is required for each case. For example, if the activation function is $\text{relu}$, we can assume that roughly half of the $y_j$ values are set to zero, so the estimated second moment result is half of the result in the previous formula:
$$ \mathbb{E}[y_j^2] = \frac{m}{2}\mathbb{E}[w_{i,j}^2] $$Thus, the initialization variance that keeps the second moment unchanged is $2/m$. This is He initialization, specifically for $\text{relu}$ networks.
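The factor of one half can be checked numerically. The sketch below assumes standard normal inputs and normally distributed weights, and compares weight variance $1/m$ with $2/m$; names and sample sizes are illustrative.

```python
import math
import random

def relu_second_moment(m, weight_var, trials, rng):
    """Monte Carlo estimate of E[relu(y)^2] for y = sum_i x_i * w_i, x_i ~ N(0, 1)."""
    w_std = math.sqrt(weight_var)
    total = 0.0
    for _ in range(trials):
        y = sum(rng.gauss(0.0, 1.0) * rng.gauss(0.0, w_std) for _ in range(m))
        total += max(y, 0.0) ** 2          # relu followed by squaring
    return total / trials

rng = random.Random(0)
m = 32
with_lecun = relu_second_moment(m, 1.0 / m, 10_000, rng)  # expected near 0.5
with_he = relu_second_moment(m, 2.0 / m, 10_000, rng)     # expected near 1.0
print(with_lecun, with_he)
```

Because the pre-activation is symmetric around zero, relu discards half of the second moment, and doubling the weight variance restores it.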
However, if the activation functions are $\text{elu},\text{gelu}$, etc., the analysis becomes less straightforward. And if the activation functions are $\tanh,\text{sigmoid}$, it’s impossible to find any initialization that makes the second moment 1. In such cases, if we still want to maintain a constant second moment, a possible solution is to “fine-tune the definition of the activation function.”
Taking $\text{sigmoid}$ as an example, assume the input has a mean of 0 and a variance of 1. If we still use “mean of 0, variance of $1/m$” for initialization, then the output before activation also has a mean of 0 and a variance of 1. We can then estimate the second moment after $\text{sigmoid}$ using a standard normal distribution:
$$ \int_{-\infty}^{\infty} \frac{e^{-x^2/2}}{\sqrt{2\pi}}\text{sigmoid}(x)^2 dx = 0.2933790\dots $$This means that under this assumption, the second moment of the model’s activated output is roughly $0.293379$. Therefore, if we want to keep the output’s second moment roughly constant, we can divide the output by $\sqrt{0.293379}$. In other words, the activation function changes from $\text{sigmoid}(x)$ to $\frac{\text{sigmoid}(x)}{\sqrt{0.293379}}$. This is the “fine-tuned” activation function. If you deem it necessary, you can also change the output mean to 0 by subtracting a constant.
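The integral can be reproduced with simple numerical quadrature. The following sketch uses a midpoint rule on $[-12, 12]$, outside of which the Gaussian weight is negligible; the helper names and the grid size are assumptions of this example.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def gaussian_expectation(f, lo=-12.0, hi=12.0, steps=4_800):
    """Midpoint-rule estimate of E[f(X)] for X ~ N(0, 1)."""
    h = (hi - lo) / steps
    total = 0.0
    for i in range(steps):
        x = lo + (i + 0.5) * h
        total += f(x) * math.exp(-x * x / 2.0)
    return total * h / math.sqrt(2.0 * math.pi)

second_moment = gaussian_expectation(lambda x: sigmoid(x) ** 2)
scale = 1.0 / math.sqrt(second_moment)

def scaled_sigmoid(x):
    """The 'fine-tuned' sigmoid: rescaled so its second moment is 1 under N(0, 1)."""
    return scale * sigmoid(x)

print(second_moment, scale)
```

The same `gaussian_expectation` helper can be reused to estimate the rescaling constant for other activation functions under the same standard normal assumption.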
I remember in 2017, a “sensational” paper titled 《Self-Normalizing Neural Networks》 proposed an activation function called $\text{selu}$. It is essentially a “fine-tuned” $\text{elu}$ function based on the same idea, and its form is as follows:
$$ \text{selu}(x)=\lambda\left\{\begin{aligned} &x,& (x > 0) \\ &\alpha e^{x}-\alpha, &(x\leq 0) \end{aligned}\right. $$where $\lambda=1.0507\dots,\alpha=1.6732\dots$. It was “sensational for a while” partly because it claimed to achieve automatic network standardization without Batch Normalization or similar methods, and partly because it was accompanied by dozens of pages of mathematical derivations that were quite “intimidating.” However, from the perspective above, it simply introduces two parameters to fine-tune the $\text{elu}$ function, ensuring that when a standard normal distribution is the input, the mean of the activated output values is 0 and the variance is 1. So, at most, it’s just a relatively good initialization. That’s why it was only sensational for “a while.” For its two parameters, we can also use Mathematica for numerical solution:
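```mathematica
(* Standard normal density *)
f[x_] = Exp[-x^2/2]/Sqrt[2 Pi];
(* selu with the two unknown constants \[Lambda] and \[Alpha] *)
s[x_] = Piecewise[{{\[Lambda]*x, x > 0}, {\[Lambda]*\[Alpha]*(Exp[x] - 1), x <= 0}}];
(* First and second moments of the activated output *)
x1 = Integrate[f[x]*s[x], {x, -Infinity, Infinity}];
x2 = Integrate[f[x]*s[x]^2, {x, -Infinity, Infinity}];
(* Require mean 0 and second moment 1, then evaluate to 20 digits *)
N[Solve[{x1 == 0, x2 == 1}, {\[Lambda], \[Alpha]}], 20]
```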
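For readers without Mathematica, the same two conditions can be solved in closed form. The sketch below relies on the identity $\mathbb{E}[e^{aX}; X\leq 0] = e^{a^2/2}\Phi(-a)$ for $X\sim\mathcal{N}(0,1)$; the helper names are hypothetical, and only the positive root for $\lambda$ is kept.

```python
import math

def normal_cdf(x):
    return 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))

def exp_moment_left(a):
    """E[exp(a * X); X <= 0] for X ~ N(0, 1), equal to exp(a^2 / 2) * Phi(-a)."""
    return math.exp(a * a / 2.0) * normal_cdf(-a)

# Mean-zero condition: E[x; x > 0] + alpha * (E[e^x; x <= 0] - P(x <= 0)) = 0
pos_mean = 1.0 / math.sqrt(2.0 * math.pi)
alpha = pos_mean / (0.5 - exp_moment_left(1.0))

# Unit second moment: lambda^2 * (E[x^2; x > 0] + alpha^2 * E[(e^x - 1)^2; x <= 0]) = 1
neg_sq = exp_moment_left(2.0) - 2.0 * exp_moment_left(1.0) + 0.5
lam = 1.0 / math.sqrt(0.5 + alpha * alpha * neg_sq)

print(lam, alpha)
```

The mean condition does not involve $\lambda$, so $\alpha$ can be solved first and $\lambda$ then follows from the second moment condition.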
Direct Standardization#
Of course, compared to such simple “fine-tuning,” a more direct approach is to use one of the Normalization methods, such as Batch Normalization, Instance Normalization, or Layer Normalization. These methods compute the mean and variance of the current data directly and use them to standardize the output, so no integral has to be estimated in advance. (The terms “standardization” and “normalization” are used interchangeably here.) The three methods are largely similar: apart from Batch Normalization’s extra step of keeping moving averages of the mean and variance for use at prediction time, they differ only in the dimensions over which the statistics are computed. For example, Layer Normalization, which is widely used in NLP and especially in Transformer models, is:
$$ y_{i,j,k} = \frac{x_{i,j,k} - \mu_{i,j}}{\sqrt{\sigma_{i,j}^2 + \epsilon}}\times\gamma_k + \beta_k,\quad \mu_{i,j} = \frac{1}{d}\sum_{k=1}^d x_{i,j,k},\quad \sigma_{i,j}^2 = \frac{1}{d}\sum_{k=1}^d (x_{i,j,k}-\mu_{i,j})^2 $$Other details will not be repeated here. For readers interested in the principle behind these methods, please refer to my previous article 《What Does BN Actually Do? A Self-Contrived Analysis》.
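The Layer Normalization formula above can be written out for a single feature vector (one fixed $(i,j)$ position) as follows. This is a minimal sketch; the function name and the default `eps` value are assumptions, not taken from any particular library.

```python
import math

def layer_norm(x, gamma, beta, eps=1e-6):
    """Layer Normalization over the last dimension of one feature vector."""
    d = len(x)
    mean = sum(x) / d
    var = sum((v - mean) ** 2 for v in x) / d
    inv_std = 1.0 / math.sqrt(var + eps)
    # center, scale, then apply the learned per-feature gain and offset
    return [(v - mean) * inv_std * g + b for v, g, b in zip(x, gamma, beta)]

x = [1.0, 2.0, 3.0, 4.0]
y = layer_norm(x, gamma=[1.0] * 4, beta=[0.0] * 4)
print(y)
```

With unit gain and zero offset, the output has mean 0 and a second moment very close to 1, regardless of the scale of the input.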
Here, I noticed an interesting phenomenon: Normalization typically includes two parts: mean subtraction (center) and division by standard deviation (scale). However, some recent works have gradually tried to remove the center step, and some even show that performance slightly improves after removing it.
For example, the 2019 paper 《Root Mean Square Layer Normalization》 compared Layer Normalization with the center step removed. The paper calls it RMS Norm, and its form is as follows:
$$ y_{i,j,k} = \frac{x_{i,j,k}}{\sqrt{\sigma_{i,j}^2 + \epsilon}}\times\gamma_k,\quad \sigma_{i,j}^2 = \frac{1}{d}\sum_{k=1}^d x_{i,j,k}^2 $$As can be seen, RMS Norm is just a simple variant of L2 Normalization. Nevertheless, the overall results in the paper show that RMS Norm is faster than Layer Normalization while performing about as well.
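For comparison with Layer Normalization, here is the RMS Norm formula for a single feature vector. Again this is only a sketch with a hypothetical function name and an assumed default `eps`.

```python
import math

def rms_norm(x, gamma, eps=1e-6):
    """RMS Norm over the last dimension: scale only, no mean subtraction, no offset."""
    d = len(x)
    mean_square = sum(v * v for v in x) / d
    inv_rms = 1.0 / math.sqrt(mean_square + eps)
    return [v * inv_rms * g for v, g in zip(x, gamma)]

x = [1.0, 2.0, 3.0, 4.0]
y = rms_norm(x, gamma=[1.0] * 4)
print(y)
```

The output has second moment close to 1 but keeps a nonzero mean, which is exactly the difference from Layer Normalization: the center step is gone.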
In addition to this paper, RMS Norm was also used by Google in T5, and in another paper, 《Do Transformer Modifications Transfer Across Implementations and Applications?》, more comprehensive comparative experiments were conducted, demonstrating the superiority of RMS Norm. It seems likely that RMS Norm will replace Layer Normalization as the standard for Transformers in the future.
Coincidentally, also in 2019, the paper 《Analyzing and Improving the Image Quality of StyleGAN》 proposed StyleGAN2, an improved version of StyleGAN. It was found that the Instance Normalization used led to “water droplets” in some generated images. They eventually removed Instance Normalization and replaced it with something called “Weight demodulation,” but they also found that simply removing the center operation while retaining Instance Normalization could improve this phenomenon. This provides further evidence that the center operation in Normalization might lead to negative effects.
An intuitive guess is that the center operation, similar to the bias term in a fully connected layer, stores prior distribution information about the pre-training task. Storing this prior distribution information directly in the model might, paradoxically, reduce the model’s transferability. Therefore, T5 not only removed the center operation for Layer Normalization but also removed the bias terms for every layer.
NTK Parameterization#
Returning to the LeCun initialization for fully connected layers derived above, it states that we should initialize with a “random distribution with a mean of 0 and a variance of $1/m$.” However, besides directly using this initialization method, we can also use another parameterization approach: initialize with a “random distribution with a mean of 0 and a variance of 1,” but then divide the output by $\sqrt{m}$. That is, the model becomes:
$$ y_j = b_j + \frac{1}{\sqrt{m}}\sum_i x_i w_{i,j} $$This is known as “NTK parameterization” in Gaussian processes. Relevant papers include 《Neural Tangent Kernel: Convergence and Generalization in Neural Networks》 and 《On the infinite width limit of neural networks with a standard parameterization》. However, for me, I first encountered this operation in the PGGAN paper 《Progressive Growing of GANs for Improved Quality, Stability, and Variation》.
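The relationship between the two parameterizations can be illustrated with a bias-free dense layer. In this sketch (function names, sizes, and seed are illustrative), the NTK form with unit-variance weights produces exactly the same initial output as the standard form with the same weights pre-divided by $\sqrt{m}$.

```python
import math
import random

def dense_standard(x, w):
    """y_j = sum_i x_i w_ij; weights are expected to have variance 1/m."""
    m, n = len(w), len(w[0])
    return [sum(x[i] * w[i][j] for i in range(m)) for j in range(n)]

def dense_ntk(x, w):
    """y_j = (1/sqrt(m)) sum_i x_i w_ij; weights are expected to have variance 1."""
    m = len(w)
    return [y / math.sqrt(m) for y in dense_standard(x, w)]

rng = random.Random(0)
m, n = 64, 2000
raw = [rng.gauss(0.0, 1.0) for _ in range(m)]
rms = math.sqrt(sum(v * v for v in raw) / m)
x = [v / rms for v in raw]                      # input rescaled to second moment 1

w = [[rng.gauss(0.0, 1.0) for _ in range(n)] for _ in range(m)]  # O(1) weights
y_ntk = dense_ntk(x, w)

w_std = [[v / math.sqrt(m) for v in row] for row in w]           # variance 1/m
y_std = dense_standard(x, w_std)

second_moment = sum(v * v for v in y_ntk) / n
print(second_moment)  # expected to be close to 1
```

The two forms agree at initialization; they differ during training, because a gradient step of a given size changes the effective weight $w/\sqrt{m}$ by a different amount than it changes $w$ itself.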
Evidently, using NTK parameterization, we can initialize all parameters with a variance of 1 while still maintaining the same second moment. Even the “fine-tuning activation functions” introduced earlier can be seen as a form of NTK parameterization. A natural question arises: What are the advantages of NTK parameterization compared to directly initializing with a variance of $1/m$?
Theoretically, there’s a slight advantage. With NTK parameterization, all parameters can be initialized with a distribution of variance 1, meaning that the magnitude of each parameter is roughly the same, at an $\mathcal{O}(1)$ level. This allows us to set larger learning rates, such as $10^{-2}$. If an adaptive optimizer is used, the update amount is approximately $\frac{\text{gradient}}{\sqrt{\text{gradient}\otimes\text{gradient}}}\times\text{learning rate}$, so we know that a $10^{-2}$ learning rate adjusts parameters by roughly $1\%$ at each step. In summary, NTK parameterization allows us to treat every parameter more equally and provides a clearer understanding of the training update magnitude, enabling better parameter tuning.
Speaking of which, we can now discuss the problem at the beginning of this article: Why is dividing by $\sqrt{d}$ so important in Attention? For two $d$-dimensional vectors $\boldsymbol{q}$ and $\boldsymbol{k}$, assuming they are both sampled from a “mean of 0, variance of 1” distribution, the second moment of their dot product is:
$$ \begin{aligned} \mathbb{E}[(\boldsymbol{q}\cdot \boldsymbol{k})^2]=&\, \mathbb{E}\left[\left(\sum_{i=1}^d q_i k_i\right)^2\right] = \mathbb{E}\left[\left(\sum_i q_i k_i\right)\left(\sum_j q_j k_j\right)\right]\\ =&\, \mathbb{E}\left[\sum_{i,j} (q_i q_j) (k_i k_j)\right] = \sum_{i,j} \mathbb{E}[q_i q_j] \mathbb{E}[k_i k_j]\\ =&\, \sum_i \mathbb{E}[q_i^2] \mathbb{E}[k_i^2] = d \end{aligned} $$That is, the second moment of the dot product is $d$. Since the mean is also 0, this means the variance is also $d$. Attention involves a dot product followed by softmax. The main operation is $e^{\boldsymbol{q}\cdot \boldsymbol{k}}$. We can roughly estimate that the values after the dot product and before softmax are in the range of $-3\sqrt{d}$ to $3\sqrt{d}$. Since $d$ is usually at least 64, $e^{3\sqrt{d}}$ is relatively large, and $e^{-3\sqrt{d}}$ is relatively small. Therefore, after softmax, the Attention distribution becomes very close to a one-hot distribution, which leads to severe vanishing gradient problems and poor training performance.
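The saturation effect can be observed directly. The sketch below draws one query and a set of keys with unit-variance components (sizes and seed are arbitrary choices) and compares the softmax over raw dot products with the softmax over dot products divided by $\sqrt{d}$.

```python
import math
import random

def softmax(logits):
    top = max(logits)                              # subtract the max for stability
    exps = [math.exp(v - top) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

rng = random.Random(0)
d, n_keys = 64, 128
q = [rng.gauss(0.0, 1.0) for _ in range(d)]
keys = [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(n_keys)]

logits = [sum(a * b for a, b in zip(q, k)) for k in keys]
p_raw = softmax(logits)                                   # no scaling
p_scaled = softmax([v / math.sqrt(d) for v in logits])    # divided by sqrt(d)

# The unscaled distribution is typically far more peaked than the scaled one.
print(max(p_raw), max(p_scaled))
```

Dividing the logits by a constant greater than 1 can only lower the largest softmax probability, so the scaled attention distribution is never sharper than the unscaled one.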
Correspondingly, there are two solutions. One is to divide by $\sqrt{d}$ after the dot product, as in NTK parameterization, to make the variance of $\boldsymbol{q}\cdot \boldsymbol{k}$ equal to 1. This prevents $e^3$ and $e^{-3}$ from being excessively large or small, so softmax doesn’t result in a one-hot distribution and vanishing gradients. This is the approach used in regular Transformers like BERT’s Self Attention. The other solution is not to divide by $\sqrt{d}$, but instead, when initializing the fully connected layers for $\boldsymbol{q}$ and $\boldsymbol{k}$, their initialization variance should be further divided by an additional $\sqrt{d}$. This also ensures that the initial variance of $\boldsymbol{q}\cdot \boldsymbol{k}$ becomes 1. T5 adopted this approach.
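Both remedies can be compared with a small Monte Carlo estimate of $\mathbb{E}[(\boldsymbol{q}\cdot\boldsymbol{k})^2]$. This sketch models the components of $\boldsymbol{q}$ and $\boldsymbol{k}$ directly as independent normal variables rather than as outputs of real projection layers, so the "re-initialized" case simply gives each component a variance of $1/\sqrt{d}$. All names and sample sizes are illustrative.

```python
import math
import random

def dot_second_moment(d, component_var, scale, pairs, rng):
    """Estimate E[(scale * q.k)^2] with q_i, k_i ~ N(0, component_var)."""
    std = math.sqrt(component_var)
    total = 0.0
    for _ in range(pairs):
        dot = sum(rng.gauss(0.0, std) * rng.gauss(0.0, std) for _ in range(d))
        total += (scale * dot) ** 2
    return total / pairs

rng = random.Random(0)
d, pairs = 64, 3_000
raw = dot_second_moment(d, 1.0, 1.0, pairs, rng)                     # near d
scaled = dot_second_moment(d, 1.0, 1.0 / math.sqrt(d), pairs, rng)   # near 1
reinit = dot_second_moment(d, d ** -0.5, 1.0, pairs, rng)            # near 1
print(raw, scaled, reinit)
```

The two remedies match only at initialization: the explicit $1/\sqrt{d}$ factor stays in the model throughout training, while the smaller initialization only sets the starting point.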
Residual Connections#
Finally, we must discuss the design related to residual connections $x + F(x)$. It can be easily proven that if the variance of $x$ (and similarly, the second moment) is $\sigma_1^2$ and the variance of $F(x)$ is $\sigma_2^2$, and assuming they are independent, then the variance of $x + F(x)$ is $\sigma_1^2 + \sigma_2^2$. This means that residual connections further amplify variance, so we also need corresponding strategies to reduce their variance.
One naive approach is to directly add a Normalization operation after the residual connection:
$$ x_{t+1} = \text{Norm}(x_t + F_t(x_t)) $$This can be called the Post Norm structure, which is also the design used in the original Transformer and BERT. However, although this approach stabilizes the variance during forward propagation, it actually severely weakens the identity branch of the residual connection, thus losing the “easy to train” advantage of residual connections. It usually requires warmup and setting sufficiently small learning rates for convergence.
How can this be understood? Assume that initially, both $x$ and $F(x)$ have a variance of 1. Then the variance of $x+F(x)$ is 2. The Normalization operation is responsible for reducing the variance back to 1. This implies that in the initial stage, Post Norm is equivalent to:
$$ x_{t+1} = \frac{x_t + F_t(x_t)}{\sqrt{2}} $$Unrolling this recursion down to $x_0$ gives, for the output $x_l$ of the $l$-th layer:
$$ \begin{aligned} x_l =&\, \frac{x_{l-1}}{\sqrt{2}} + \frac{F_{l-1}(x_{l-1})}{\sqrt{2}} \\ =&\, \frac{x_{l-2}}{2} + \frac{F_{l-2}(x_{l-2})}{2} + \frac{F_{l-1}(x_{l-1})}{\sqrt{2}} \\ =&\, \cdots \\ =&\, \frac{x_0}{2^{l/2}} + \frac{F_0(x_0)}{2^{l/2}} + \frac{F_1(x_1)}{2^{(l-1)/2}} + \frac{F_2(x_2)}{2^{(l-2)/2}} + \cdots + \frac{F_{l-1}(x_{l-1})}{2^{1/2}} \end{aligned} $$Do you see the problem? The original intention of residual connections was to provide a “green channel” for earlier layers, allowing gradients to propagate back more directly. However, in Post Norm, this “green channel” is severely weakened; channels closer to the beginning have smaller weights, making the residual connection “exist in name only,” and thus still difficult to train. For related analysis, you can refer to the paper 《On Layer Normalization in the Transformer Architecture》.
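To make the decay concrete, the sketch below unrolls the simplified Post Norm recursion $x_{t+1} = (x_t + F_t(x_t))/\sqrt{2}$ and tracks the coefficient attached to each term, then compares the result with the closed form. It only does bookkeeping on coefficients under the initial-stage assumption stated above; the function names are hypothetical.

```python
import math

def post_norm_coefficients(num_layers):
    """Return [weight of x_0, weight of F_0, ..., weight of F_{l-1}] in x_l."""
    coeffs = [1.0]                                      # start with x_0 alone
    for _ in range(num_layers):
        coeffs.append(1.0)                              # add the branch output F_t(x_t)
        coeffs = [c / math.sqrt(2.0) for c in coeffs]   # Norm rescales every term
    return coeffs

def closed_form(num_layers):
    l = num_layers
    return [2.0 ** (-l / 2)] + [2.0 ** (-(l - k) / 2) for k in range(l)]

coeffs = post_norm_coefficients(12)
print(coeffs[0], coeffs[-1])   # weight of x_0 versus weight of the last branch
```

For 12 layers the input $x_0$ is already scaled by $2^{-6} = 1/64$, while the last branch keeps a weight of $1/\sqrt{2}$; the squared weights sum to 1, which is how the variance stays fixed.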
A targeted improvement is called Pre Norm, which is based on the idea of “standardizing only when needed.” Its form is:
$$ x_{t+1} = x_t + F_t(\text{Norm}(x_t)) $$Similarly, expanding iteratively, we can assume that in the initial stage:
$$ x_l = x_0 + F_0(x_0) + F_1(x_1/\sqrt{2}) + F_2(x_2/\sqrt{3}) + \cdots + F_{l-1}(x_{l-1}/\sqrt{l}) $$In this way, at least each residual channel is treated equally, and the effect of residual connections will be more pronounced than with Post Norm, making it easier to optimize. Of course, the variance of the final $x_l$ will be very large in this case, so an additional Normalization is still needed for $x_l$ before connecting to the prediction layer.
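The contrast between the two structures can be seen in a toy simulation. This is a deliberately simplified sketch: each branch output is replaced by a fresh unit-variance sample that is independent of its input (the same independence assumption used for the variance of $x + F(x)$), and Post Norm is modeled as division by $\sqrt{2}$. Names, seed, and sizes are illustrative.

```python
import math
import random
import statistics

def final_variance(num_layers, samples, mode, rng):
    """Variance of x_l in a scalar toy model of Post Norm or Pre Norm."""
    outputs = []
    for _ in range(samples):
        x = rng.gauss(0.0, 1.0)                  # x_0 with unit variance
        for _ in range(num_layers):
            branch = rng.gauss(0.0, 1.0)         # stand-in for the branch output
            if mode == "post":
                x = (x + branch) / math.sqrt(2.0)   # Norm applied after the sum
            else:
                x = x + branch                      # Pre Norm leaves the sum unnormalized
        outputs.append(x)
    return statistics.pvariance(outputs)

rng = random.Random(0)
post_var = final_variance(8, 20_000, "post", rng)   # stays near 1
pre_var = final_variance(8, 20_000, "pre", rng)     # grows to about l + 1 = 9
print(post_var, pre_var)
```

The linear growth of the Pre Norm variance is why a final Normalization is applied to $x_l$ before the prediction layer.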
In my opinion, neither Post Norm nor Pre Norm is perfect, as neither can maintain an identity function in the initial stage. From my perspective, the most elegant method should be to introduce a scalar parameter $\alpha_t$, initialized to 0, such that:
$$ x_{t+1} = x_t + \alpha_t F_t(x_t) $$And then gradually update $\alpha_t$. This way, we can ensure the model acts as an identity function in the initial stage, thus avoiding variance issues. This trick later appeared in two papers: in 《Batch Normalization Biases Residual Blocks Towards the Identity Function in Deep Networks》 it was called SkipInit, and in 《ReZero is All You Need: Fast Convergence at Large Depth》 it was called ReZero. The two papers were published less than a month apart, and their results both showed that this treatment could essentially replace Normalization operations within residual connections. Furthermore, 《Fixup Initialization: Residual Learning Without Normalization》 proposed a method called Fixup, which initializes the last layer of each residual branch with all zeros, sharing some commonalities with SkipInit and ReZero.
For the update of $\alpha_t$, both SkipInit and ReZero treat it as a model parameter that is updated along with the other model parameters, which was also my initial thought. Later, I realized that $\alpha_t$ does not have the same status as the other parameters and should not be treated in the same way. For example, with the NTK parameterization introduced earlier, we can use a large learning rate for the other parameters, but $\alpha_t$ clearly should not use a large learning rate. Moreover, we know that if training succeeds, both Post Norm and Pre Norm perform well (corresponding to $\alpha_t=1$). Therefore, the choice of residual mode is purely an initialization problem, not a fitting capability problem. Considering these points, I later simply let $\alpha_t$ increase with a fixed, very small step size until it reached $\alpha_t=1$, and then kept it fixed. In my experiments, this update mode gave the best results.
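A fixed-step schedule of this kind might look like the sketch below. The article does not state the actual step size, so the `increment` value here is purely hypothetical, as are the function names.

```python
def alpha_schedule(step, increment=1e-3):
    """Residual gate: grows linearly from 0 by a fixed increment, then stays at 1.

    The increment is a hypothetical value chosen for illustration only.
    """
    return min(1.0, step * increment)

def residual_step(x, branch_out, alpha):
    """x + alpha * F(x), with the branch output F(x) supplied by the caller."""
    return [xi + alpha * fi for xi, fi in zip(x, branch_out)]

x = [1.0, 2.0]
branch_out = [5.0, 5.0]
print(residual_step(x, branch_out, alpha_schedule(0)))       # identity at step 0
print(residual_step(x, branch_out, alpha_schedule(10_000)))  # plain x + F(x) once alpha is 1
```

Because the gate is a function of the training step rather than a trainable parameter, it needs no learning rate of its own.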
Summary#
This article discussed issues related to model initialization, parameterization, and normalization, hoping to provide some reference for your model “alchemy” and hyperparameter tuning. The road of “alchemy” is endless. Besides these contents, there are many other adjustable things, such as learning rate, optimizer, and data augmentation. May all readers have a smooth journey in the path of “alchemy”~
@online{kexuefm-8620,
title={A Brief Discussion on Initialization, Parameterization, and Normalization of Transformers},
author={苏剑林},
year={2021},
month={08},
url={\url{https://kexue.fm/archives/8620}},
}