
15. Key Normalization Boosts Length Extrapolation

Road to a better Transformer - This article is part of a series.
Part 15: This Article
This is a gemini-2.5-flash-preview-04-17 translation of a Chinese article. Beware of potential errors.

Broadly speaking, current Transformer length extrapolation techniques fall into two categories. The first is post-hoc modification, such as NTK-RoPE, YaRN, and ReRoPE. These methods modify the model directly at inference time and achieve some degree of length extrapolation without fine-tuning; their drawback is that they cannot leave the model's behavior within the training length unchanged. The second category is, naturally, pre-hoc modification, such as ALIBI, KERPLE, XPOS, and HWFA, which achieve some length extrapolation without any change at inference time. However, their modifications must be introduced before training, so they cannot be applied to off-the-shelf models without fine-tuning, and whether these methods scale up has not yet been widely established.

In this article, the author introduces a length extrapolation method discovered by accident, “KeyNorm”, which applies L2 normalization to the Attention Key sequence. It clearly belongs to the pre-hoc category, but it changes the Attention mechanism so little that it looks very promising for scaling up.

Initial Motivation

It is called an “unexpected discovery” because the original motivation for this modification was not length extrapolation but an attempt to replace the scaling in Scaled Dot-Product Attention. Recall that the standard definition of Attention is (this article mainly considers the causal setting)

$$ \boldsymbol{o}_i = \frac{\sum_{j = 1}^i\exp\left(\frac{\boldsymbol{q}_i\cdot \boldsymbol{k}_j}{\sqrt{d}}\right)\boldsymbol{v}_j}{\sum_{j = 1}^i\exp\left(\frac{\boldsymbol{q}_i\cdot \boldsymbol{k}_j}{\sqrt{d}}\right)} $$

where $\boldsymbol{q}_i,\boldsymbol{k}_j\in\mathbb{R}^d$. We have explained and even generalized the scale factor $\frac{1}{\sqrt{d}}$ several times, for example in “A Brief Discussion on Transformer Initialization, Parameterization, and Normalization”, “Viewing Attention’s Scale Operation from Entropy Invariance”, and “Viewing Attention’s Scale Operation from Gradient Maximization”. The standard derivation assumes that “the components of $\boldsymbol{q}_i,\boldsymbol{k}_j$ are independently sampled from a distribution with mean 0 and variance 1”. Under this assumption, we also have

$$ \Vert\boldsymbol{q}_i\Vert\approx \sqrt{d},\quad \Vert\boldsymbol{k}_j\Vert\approx \sqrt{d} $$

This is because

$$ \Vert\boldsymbol{x}\Vert^2 = \sum_{i=1}^d x_i^2 = d\times\frac{1}{d}\sum_{i=1}^d x_i^2\approx d\,\mathbb{E}_{x\sim\mathcal{N}(0,1)}[x^2] = d $$
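A quick numerical check of this approximation, as a minimal pure-Python sketch that assumes standard-normal components (the function name `norm_ratio_stats` is our own): the average of $\Vert\boldsymbol{x}\Vert/\sqrt{d}$ should stay close to 1, and its spread should shrink as $d$ grows.

```python
import math
import random


def norm_ratio_stats(d: int, trials: int = 200, seed: int = 0):
    """Mean and std of ||x|| / sqrt(d) for x with i.i.d. N(0, 1) components."""
    rng = random.Random(seed)
    ratios = []
    for _ in range(trials):
        x = [rng.gauss(0.0, 1.0) for _ in range(d)]
        ratios.append(math.sqrt(sum(v * v for v in x)) / math.sqrt(d))
    mean = sum(ratios) / trials
    std = math.sqrt(sum((r - mean) ** 2 for r in ratios) / trials)
    return mean, std


for d in (16, 128, 1024):
    mean, std = norm_ratio_stats(d)
    print(f"d={d:5d}  mean={mean:.3f}  std={std:.3f}")
```

The shrinking spread is what makes $\frac{1}{\sqrt{d}}$ and $\frac{1}{\Vert\boldsymbol{k}_j\Vert}$ (or $\frac{1}{\Vert\boldsymbol{q}_i\Vert}$) nearly interchangeable at initialization.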

For related generalizations, see “The Astonishing Johnson-Lindenstrauss Lemma: Theory”. This approximation means that, at the initial stage of training, standard Scaled Dot-Product Attention behaves the same as the following two variants:

$$ \begin{align} \color{red}{\text{Q}}\text{uery}\color{red}{\text{N}}\text{orm:}\quad\boldsymbol{o}_i =&\, \frac{\sum_{j = 1}^i\exp\left(\tilde{\boldsymbol{q}}_i\cdot \boldsymbol{k}_j\right)\boldsymbol{v}_j}{\sum_{j = 1}^i\exp\left(\tilde{\boldsymbol{q}}_i\cdot \boldsymbol{k}_j\right)},\qquad \tilde{\boldsymbol{q}}_i = \frac{\boldsymbol{q}_i}{\Vert\boldsymbol{q}_i\Vert} \\[5pt] \color{red}{\text{K}}\text{ey}\color{red}{\text{N}}\text{orm:}\quad\boldsymbol{o}_i =&\, \frac{\sum_{j = 1}^i\exp\left(\boldsymbol{q}_i\cdot \tilde{\boldsymbol{k}}_j\right)\boldsymbol{v}_j}{\sum_{j = 1}^i\exp\left(\boldsymbol{q}_i\cdot \tilde{\boldsymbol{k}}_j\right)},\qquad \tilde{\boldsymbol{k}}_j = \frac{\boldsymbol{k}_j}{\Vert\boldsymbol{k}_j\Vert} \end{align} $$

This raised the question of which works better: either of these two variants or standard Scaled Dot-Product Attention. For convenience, we call them “Query/Key-Normalized Dot-Product Attention”, abbreviated “QNA” and “KNA” respectively.
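To make the three definitions concrete, here is a minimal single-head, unbatched sketch in pure Python. The function `causal_attention` and its `mode` argument are our own illustration, not the author's code; QNA and KNA drop the $\frac{1}{\sqrt{d}}$ factor exactly as in the formulas above.

```python
import math
from typing import List

Vector = List[float]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def l2_normalize(x: Vector, eps: float = 1e-12) -> Vector:
    norm = math.sqrt(dot(x, x))
    return [v / max(norm, eps) for v in x]


def causal_attention(Q: List[Vector], K: List[Vector], V: List[Vector],
                     mode: str = "sdpa") -> List[Vector]:
    """Single-head causal attention.

    mode="sdpa": logits q_i . k_j / sqrt(d)   (standard)
    mode="qna":  logits q_i/||q_i|| . k_j      (QueryNorm, no 1/sqrt(d))
    mode="kna":  logits q_i . k_j/||k_j||      (KeyNorm, no 1/sqrt(d))
    """
    d = len(Q[0])
    if mode == "qna":
        Q = [l2_normalize(q) for q in Q]
    elif mode == "kna":
        K = [l2_normalize(k) for k in K]
    elif mode != "sdpa":
        raise ValueError(f"unknown mode: {mode}")
    scale = 1.0 / math.sqrt(d) if mode == "sdpa" else 1.0

    outputs = []
    for i, q in enumerate(Q):
        # Causal mask: position i only attends to j = 0..i.
        logits = [scale * dot(q, K[j]) for j in range(i + 1)]
        m = max(logits)  # subtract the max for numerical stability
        weights = [math.exp(s - m) for s in logits]
        z = sum(weights)
        outputs.append([
            sum(weights[j] * V[j][c] for j in range(i + 1)) / z
            for c in range(len(V[0]))
        ])
    return outputs
```

Because KNA divides every $\boldsymbol{k}_j$ by its own norm, rescaling any key by a positive factor leaves the KNA output unchanged, while it does change the standard output; this is the degree of freedom that the Principle Analysis section argues KeyNorm removes.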

Furthermore, since we can normalize the Query and the Key separately, it is natural to normalize both, so we also experimented with the following “Scaled Cosine Attention (CosA)”:

$$ \boldsymbol{o}_i = \frac{\sum_{j = 1}^i\exp\left(\lambda\,\tilde{\boldsymbol{q}}_i\cdot \tilde{\boldsymbol{k}}_j\right)\boldsymbol{v}_j}{\sum_{j = 1}^i\exp\left(\lambda\,\tilde{\boldsymbol{q}}_i\cdot \tilde{\boldsymbol{k}}_j\right)} = \frac{\sum_{j = 1}^i\exp\left(\lambda\cos(\boldsymbol{q}_i,\boldsymbol{k}_j)\right)\boldsymbol{v}_j}{\sum_{j = 1}^i\exp\left(\lambda\cos(\boldsymbol{q}_i,\boldsymbol{k}_j)\right)} $$

Here $\lambda$ follows the result of “Viewing Attention’s Scale Operation from Gradient Maximization”, namely $\lambda = 4\log n$ (that article used 3.5, but the training length below is relatively short, so 4 is more accurate), where $n$ is either fixed at half the training length or set dynamically to the position id plus 1.
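A minimal sketch of causal CosA with $\lambda = 4\log n$. The text does not spell out which choice of $n$ each experimental variant uses, so the sketch exposes both through a hypothetical `train_len` argument: when given, $n$ is half the training length; otherwise $n$ is the (0-based) position id plus 1. The helper names are our own.

```python
import math
from typing import List, Optional

Vector = List[float]


def l2_normalize(x: Vector, eps: float = 1e-12) -> Vector:
    norm = math.sqrt(sum(v * v for v in x))
    return [v / max(norm, eps) for v in x]


def cosa_lambda(i: int, train_len: Optional[int] = None) -> float:
    """lambda = 4 log n, with n = train_len / 2 (fixed) or n = i + 1 (dynamic)."""
    n = train_len / 2 if train_len is not None else i + 1
    return 4.0 * math.log(n)


def cosine_attention(Q: List[Vector], K: List[Vector], V: List[Vector],
                     train_len: Optional[int] = None) -> List[Vector]:
    """Causal Scaled Cosine Attention: logits = lambda * cos(q_i, k_j)."""
    Qn = [l2_normalize(q) for q in Q]
    Kn = [l2_normalize(k) for k in K]
    outputs = []
    for i, q in enumerate(Qn):
        lam = cosa_lambda(i, train_len)
        # Cosines lie in [-1, 1], so every logit is bounded by lambda.
        logits = [lam * sum(a * b for a, b in zip(q, Kn[j])) for j in range(i + 1)]
        m = max(logits)
        weights = [math.exp(s - m) for s in logits]
        z = sum(weights)
        outputs.append([
            sum(weights[j] * V[j][c] for j in range(i + 1)) / z
            for c in range(len(V[0]))
        ])
    return outputs
```

With both vectors normalized, the output is invariant to rescaling any query or key, and the only remaining knob for sharpening attention is $\lambda$.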

Look at the Results First

Following the length extrapolation setup used previously, all models are small (100 million parameters) with the GAU architecture, trained for the same number of steps (due to limited time, the models are not fully trained at this step count) at a training length of 512, and evaluated for extrapolation to length 4096. The results are shown in the table below. Baseline is standard Scaled Dot-Product Attention, and the suffix $\text{-}\log n$ means adding the length-dependent scaling factor introduced in “Viewing Attention’s Scale Operation from Entropy Invariance”. The evaluation metric is per-token language-model accuracy; higher is better.

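| Test Length | 512 (Train) | 4096 (Repeated) | 4096 (Non-repeated) |
| --- | --- | --- | --- |
| Baseline | 49.41% | 24.17% | 23.16% |
| Baseline-$\log n$ | 49.40% | 24.60% | 24.02% |
| QNA | 49.55% | 22.45% | 22.18% |
| QNA-$\log n$ | 49.42% | 19.55% | 18.74% |
| KNA | 49.60% | 61.08% | 47.69% |
| KNA-$\log n$ | 49.58% | 63.17% | 46.40% |
| CosA | 49.73% | 58.90% | 46.98% |
| CosA-$\log n$ | 49.67% | 64.74% | 48.95% |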

From the table, we can see that:

1. Both QueryNorm and KeyNorm do better at the training length. The advantage is very slight and will likely become negligible with further training, but it is very stable, which suggests they may make training smoother;
2. KeyNorm significantly improves length extrapolation, which is the “unexpected surprise” in these results!

Note that unlike NTK-RoPE, YaRN, etc., which modify the model at inference time, KNA and CosA extrapolate without any modification during inference. Some readers might therefore wonder: since KNA and CosA already extrapolate so well without inference-time changes, would combining them with extrapolation techniques like NTK-RoPE and YaRN make them “even better”? The author tested this as well; the results are shown in the table below:

| Test Length | 512 (Train) | 4096 (Repeated) | 4096 (Non-repeated) |
| --- | --- | --- | --- |
| Baseline | 49.41% | 24.17% | 23.16% |
| Baseline-NTK | 49.41% | 60.57% | 42.20% |
| Baseline-YaRN | 49.41% | 80.10% | 47.45% |
| Baseline-ReRoPE | 49.41% | 76.11% | 47.82% |
| Baseline-$\log n$ | 49.40% | 24.60% | 24.02% |
| Baseline-$\log n$-NTK | 49.40% | 75.86% | 47.06% |
| Baseline-$\log n$-YaRN | 49.40% | 82.57% | 46.52% |
| Baseline-$\log n$-ReRoPE | 49.40% | 85.47% | 48.87% |
| QNA | 49.55% | 22.45% | 22.18% |
| QNA-NTK | 49.55% | 52.28% | 39.88% |
| QNA-YaRN | 49.55% | 82.53% | 47.50% |
| QNA-ReRoPE | 49.55% | 78.22% | 47.72% |
| QNA-$\log n$ | 49.42% | 19.55% | 18.74% |
| QNA-$\log n$-NTK | 49.42% | 57.44% | 41.56% |
| QNA-$\log n$-YaRN | 49.42% | 80.08% | 45.16% |
| QNA-$\log n$-ReRoPE | 49.42% | 84.71% | 48.31% |
| KNA | 49.60% | 61.08% | 47.69% |
| KNA-NTK | 49.60% | 64.44% | 43.02% |
| KNA-YaRN | 49.60% | 84.19% | 47.44% |
| KNA-ReRoPE | 49.60% | 77.76% | 47.73% |
| KNA-$\log n$ | 49.58% | 63.17% | 46.40% |
| KNA-$\log n$-NTK | 49.58% | 79.05% | 47.43% |
| KNA-$\log n$-YaRN | 49.58% | 83.95% | 47.16% |
| KNA-$\log n$-ReRoPE | 49.58% | 85.48% | 48.78% |
| CosA | 49.73% | 58.90% | 46.98% |
| CosA-NTK | 49.73% | 62.50% | 42.77% |
| CosA-YaRN | 49.73% | 83.40% | 47.80% |
| CosA-ReRoPE | 49.73% | 77.82% | 47.80% |
| CosA-$\log n$ | 49.67% | 64.74% | 48.39% |
| CosA-$\log n$-NTK | 49.67% | 78.97% | 47.46% |
| CosA-$\log n$-YaRN | 49.67% | 82.28% | 45.72% |
| CosA-$\log n$-ReRoPE | 49.67% | 85.67% | 48.39% |

This table is rather long; it is meant to give a comprehensive picture of how the mainstream length extrapolation techniques compare. You can compare along whichever dimensions interest you, but note that for judging length extrapolation, the “Non-repeated” column matters most and the “Repeated” column is secondary. The results are quite unexpected: KeyNorm seems to be “immune” to existing RoPE extrapolation techniques, since stacking NTK, YaRN, etc. on top of it brings no significant improvement and can even hurt performance. Overall, though, these techniques still improve the “Repeated” column significantly, while the “Non-repeated” column does not improve. These results indicate that KeyNorm still cannot effectively identify positions beyond the training length (hence its relatively low “Repeated” results), but it does effectively avoid the PPL explosion problem (hence its still-good “Non-repeated” results).

This may be good news for researchers working on long context. On the one hand, unlike ALIBI, KERPLE, etc., KeyNorm does not need a local constraint and requires no modification after training, so it is essentially a “free lunch”; adding KeyNorm even seems to slightly improve training performance. On the other hand, because it is non-local, the model can continue training on longer texts, and during continued training there is no longer any need to choose between PI (Position Interpolation) and ABF (Adjusted Base Frequency): for KeyNorm, just leave everything unchanged.

Principle Analysis

Although this was an unexpected discovery, we still need to try to explain it; otherwise it remains just an accident. In this section, we think about why KeyNorm helps with length extrapolation.

Let us return to standard Scaled Dot-Product Attention. The relevance score between the $i$-th and $j$-th tokens is computed by the dot product:

$$ s(j|i) = \boldsymbol{q}_i\cdot \boldsymbol{k}_j = \Vert\boldsymbol{q}_i\Vert \Vert\boldsymbol{k}_j\Vert \cos(\boldsymbol{q}_i,\boldsymbol{k}_j),\quad p(j|i) = \frac{\exp\left(\frac{s(j|i)}{\sqrt{d}}\right)}{\sum_{j'=1}^i \exp\left(\frac{s(j'|i)}{\sqrt{d}}\right)} $$

In the second equality, we decompose the score geometrically into the product of the two magnitudes and the cosine of the angle between the vectors. Attention $p(j|i)$ is a conditional probability over $j$. $\Vert\boldsymbol{q}_i\Vert$ depends only on the current position $i$: it does not change the relative ordering of the attention weights, only their sparsity. $\Vert\boldsymbol{k}_j\Vert$ can change the relative size of $p(j|i)$, but it does not involve the interaction between $i$ and $j$, so it can express absolute signals; for example, Scissorhands indicates that tokens at certain absolute positions consistently receive high attention, which could be expressed through $\Vert\boldsymbol{k}_j\Vert$. The remaining term, $\cos(\boldsymbol{q}_i,\boldsymbol{k}_j)$, expresses the interaction between $i$ and $j$ and has the most degrees of freedom.
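A toy illustration of these roles, with hand-picked vectors (the values are arbitrary and not from the experiments): scaling $\boldsymbol{q}_i$ keeps the ranking of $p(j|i)$ and only sharpens it (lower entropy), while scaling a single $\boldsymbol{k}_j$ can move that key to the top.

```python
import math


def softmax(logits):
    m = max(logits)
    w = [math.exp(s - m) for s in logits]
    z = sum(w)
    return [x / z for x in w]


def entropy(p):
    return -sum(x * math.log(x) for x in p if x > 0)


def attention_probs(q, keys):
    """p(j|i) for one query over the given keys, with the 1/sqrt(d) scale."""
    d = len(q)
    return softmax([sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in keys])


def ranking(p):
    """Key indices sorted from highest to lowest attention weight."""
    return sorted(range(len(p)), key=lambda j: -p[j])


q = [1.0, 0.5, -0.5, 0.2]
keys = [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.6, 0.6, 0.0, 0.0]]

p = attention_probs(q, keys)
p_sharp = attention_probs([3.0 * x for x in q], keys)                        # ||q_i|| x 3
p_boost = attention_probs(q, [keys[0], [5.0 * x for x in keys[1]], keys[2]])  # ||k_2|| x 5

print(ranking(p), ranking(p_sharp), ranking(p_boost))  # [0, 2, 1] [0, 2, 1] [1, 0, 2]
print(round(entropy(p), 3), round(entropy(p_sharp), 3))
```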

Clearly, to increase the relative importance of a position $j$, the model has two options: 1. increase the magnitude $\Vert\boldsymbol{k}_j\Vert$; 2. increase $\cos(\boldsymbol{q}_i,\boldsymbol{k}_j)$, i.e., reduce the angle between $\boldsymbol{q}_i$ and $\boldsymbol{k}_j$. However, because of the “curse of dimensionality”, substantially changing angles in a high-dimensional space is comparatively hard, so whenever the goal can be reached by increasing $\Vert\boldsymbol{k}_j\Vert$, the model will prefer that route. The direct consequence is that $\cos(\boldsymbol{q}_i,\boldsymbol{k}_j)$ may be insufficiently trained.
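One way to see the “curse of dimensionality” intuition, as a sketch that uses random Gaussian vectors as a stand-in for untrained queries and keys (an assumption, not a model of training dynamics): the cosine between independent random vectors concentrates around 0 with standard deviation about $1/\sqrt{d}$, so pushing it far from 0 requires coordinating many coordinates, whereas scaling $\Vert\boldsymbol{k}_j\Vert$ is a single degree of freedom.

```python
import math
import random


def cosine_std(d: int, trials: int = 1000, seed: int = 0) -> float:
    """Empirical std of cos(a, b) for independent standard-normal vectors in R^d."""
    rng = random.Random(seed)
    values = []
    for _ in range(trials):
        a = [rng.gauss(0.0, 1.0) for _ in range(d)]
        b = [rng.gauss(0.0, 1.0) for _ in range(d)]
        dot = sum(x * y for x, y in zip(a, b))
        values.append(dot / math.sqrt(sum(x * x for x in a) * sum(y * y for y in b)))
    mean = sum(values) / trials
    return math.sqrt(sum((v - mean) ** 2 for v in values) / trials)


for d in (16, 64, 256):
    print(f"d={d:4d}  std(cos)={cosine_std(d):.3f}  1/sqrt(d)={1 / math.sqrt(d):.3f}")
```

This only illustrates the geometry behind the argument; it does not by itself show how gradient descent allocates updates between norms and angles.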

Here the author makes a conjecture:

Insufficient training of $\cos(\boldsymbol{q}_i,\boldsymbol{k}_j)$ is the main reason why Attention cannot length extrapolate.

Insufficient training of $\cos(\boldsymbol{q}_i,\boldsymbol{k}_j)$ means that the trained $\boldsymbol{q}_i$ and $\boldsymbol{k}_j$ only cover a limited set of angles; during length extrapolation, the model faces a larger set of angles and therefore fails to predict correctly. A careful look at the derivation in the YaRN paper shows that NTK and YaRN work because they modify the RoPE implementation at inference time so that the angles between $\boldsymbol{q}_i$ and $\boldsymbol{k}_j$ fall within the limited set seen during training, avoiding the unfamiliar larger set and turning extrapolation into interpolation. ReRoPE is even more direct: it truncates relative positions outside the window, which ensures that no position encoding at inference time is “unfamiliar”. To some extent, these techniques indirectly support this conjecture.
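The “extrapolation into interpolation” effect can be checked for NTK-aware scaled RoPE, which enlarges the RoPE base to $b\cdot s^{d/(d-2)}$ for a length ratio $s$. The sketch below assumes head dimension 64, base 10000, and $s = 4096/512 = 8$ to match the setting above; YaRN's more elaborate per-frequency scheme is not shown. With the enlarged base, the lowest frequency reaches at length 4096 exactly the largest angle it reached at length 512 during training, while the highest frequency is unchanged.

```python
import math
from typing import List


def rope_frequencies(d: int, base: float = 10000.0) -> List[float]:
    """RoPE angular frequencies theta_i = base^(-2i/d) for i = 0..d/2-1."""
    return [base ** (-2.0 * i / d) for i in range(d // 2)]


def ntk_scaled_base(base: float, d: int, scale: float) -> float:
    """NTK-aware scaled RoPE: enlarge the base to base * scale^(d / (d - 2))."""
    return base * scale ** (d / (d - 2))


d, train_len, test_len = 64, 512, 4096
scale = test_len / train_len

orig = rope_frequencies(d)
ntk = rope_frequencies(d, ntk_scaled_base(10000.0, d, scale))

# Largest angle of the lowest frequency: seen in training vs. reached at 4096 with NTK.
print(orig[-1] * train_len, ntk[-1] * test_len)
# Without NTK, the same frequency would reach 8x the largest training angle.
print(orig[-1] * test_len / (orig[-1] * train_len))
```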

Given this conjecture, the reason KeyNorm extrapolates becomes simple. Both KNA (KeyNorm only) and CosA (QueryNorm plus KeyNorm) remove $\Vert\boldsymbol{k}_j\Vert$ from the definition of Attention, so the only way for the model to change the relative importance of $j$ is to “adjust $\cos(\boldsymbol{q}_i,\boldsymbol{k}_j)$”. This forces the model to train and use $\cos(\boldsymbol{q}_i,\boldsymbol{k}_j)$ more fully, which indirectly promotes length extrapolation. The author also tried the “KeyNorm + NoPE” combination but found no length extrapolation, which indicates that RoPE also plays an important role in KeyNorm's length extrapolation. This is not hard to understand: RoPE rotates $\boldsymbol{q}_i$ and $\boldsymbol{k}_j$, which helps widen the range of $\cos(\boldsymbol{q}_i,\boldsymbol{k}_j)$ covered during training and thus makes its training more sufficient.
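A small sketch of how KeyNorm and RoPE fit together. It assumes the interleaved pairing convention (rotating $(x_{2i}, x_{2i+1})$) and base 10000; some implementations split the vector into halves instead. Because RoPE is a rotation, it preserves norms, so normalizing the key before or after RoPE gives the same vector, and the KeyNorm logit still depends only on the relative position $i-j$. The helper `kna_rope_logit` is our own name.

```python
import math
from typing import List

Vector = List[float]


def rope(x: Vector, pos: int, base: float = 10000.0) -> Vector:
    """Rotate each pair (x[2i], x[2i+1]) by angle pos * base^(-2i/d)."""
    d = len(x)
    out: Vector = []
    for i in range(d // 2):
        theta = pos * base ** (-2.0 * i / d)
        c, s = math.cos(theta), math.sin(theta)
        a, b = x[2 * i], x[2 * i + 1]
        out.extend([a * c - b * s, a * s + b * c])
    return out


def norm(x: Vector) -> float:
    return math.sqrt(sum(v * v for v in x))


def l2_normalize(x: Vector) -> Vector:
    n = norm(x)
    return [v / n for v in x]


def kna_rope_logit(q: Vector, k: Vector, i: int, j: int) -> float:
    """KeyNorm logit with RoPE: (R_i q) . normalize(R_j k)."""
    qi = rope(q, i)
    kj = l2_normalize(rope(k, j))
    return sum(a * b for a, b in zip(qi, kj))


q = [0.3, -1.2, 0.8, 0.5, -0.4, 1.1]
k = [1.5, 0.2, -0.7, 2.0, 0.9, -0.3]
print(norm(k), norm(rope(k, 1000)))                                 # equal: rotations preserve norms
print(kna_rope_logit(q, k, 5, 3), kna_rope_logit(q, k, 105, 103))  # equal: depends on i - j only
```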

Has anyone tried QueryNorm and KeyNorm before? Yes. The 2020 paper “Query-Key Normalization for Transformers” experimented with CosA and also proposed a similar scale factor that grows logarithmically with length, but it did not discuss length extrapolation. In addition, Google's paper “Scaling Vision Transformers to 22 Billion Parameters” from earlier this year also normalizes the Query and Key, but with LayerNorm. Both LayerNorm and RMSNorm have a learnable gamma parameter, so the magnitude of the normalized vector is not necessarily constant, and it is hard to say whether they would achieve the same length extrapolation effect as the KeyNorm in this article.
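A minimal sketch of why a learnable gamma matters, using RMSNorm (LayerNorm additionally subtracts the mean and adds a bias, which is not shown). With gamma fixed at ones, every normalized key has norm about $\sqrt{d}$, so it is KeyNorm up to a constant factor; with a non-uniform gamma, which we pick by hand as a hypothetical learned value, the norm depends on the key's direction, so a norm-like signal can reappear.

```python
import math
from typing import List

Vector = List[float]


def rms_norm(x: Vector, gamma: Vector, eps: float = 1e-6) -> Vector:
    """RMSNorm with a learnable per-channel gain gamma."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [g * v / rms for g, v in zip(gamma, x)]


def norm(x: Vector) -> float:
    return math.sqrt(sum(v * v for v in x))


k1 = [1.0, 0.0, 0.0, 0.0]
k2 = [0.0, 1.0, 0.0, 0.0]
unit_gamma = [1.0, 1.0, 1.0, 1.0]
learned_gamma = [2.0, 0.5, 1.0, 1.0]  # hypothetical values after training

print(norm(rms_norm(k1, unit_gamma)), norm(rms_norm(k2, unit_gamma)))        # both ~ sqrt(4) = 2
print(norm(rms_norm(k1, learned_gamma)), norm(rms_norm(k2, learned_gamma)))  # ~ 4 and ~ 1
```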

Summary

This article introduced “KeyNorm”, a length extrapolation method the author discovered by accident: applying L2 normalization to the Attention Key sequence. It performs slightly better at the training length and improves length extrapolation significantly. It belongs to the “pre-hoc modification” methods; compared with other pre-hoc methods such as ALIBI and KERPLE, it has no local constraint, which makes it more promising for scaling up. Compared with “post-hoc modification” methods such as NTK-RoPE and YaRN, it does not sacrifice performance within the training length when extrapolating.

@online{kexuefm-9859,
        title={Transformer Upgrade Path: 15. Key Normalization Boosts Length Extrapolation},
        author={苏剑林},
        year={2023},
        month={11},
        url={\url{https://kexue.fm/archives/9859}},
}