
VQing the Key Makes Transformer Complexity Linear


This is a gemini-2.5-flash translation of a Chinese article.

It has NOT been vetted for errors. You should have the original article open in a parallel tab at all times.

By Su Jianlin | 2023-11-09 | 97842 readers

Efficient Transformer refers to all efforts aimed at reducing the quadratic complexity of Transformer models. Initially, it focused specifically on improvements to Attention, but later, more general approaches, such as Fourier transforms and linear RNNs, were also placed under this umbrella. It’s undeniable that, in pursuit of lowering Transformer’s quadratic complexity, researchers have been “like the Eight Immortals crossing the sea, each displaying their own powers,” producing a dazzling array of ingenious ideas. I, too, have learned a great deal of theory from these endeavors. However, despite their theoretical brilliance, Efficient Transformers have consistently remained lukewarm in practice, without any truly outstanding models emerging. In today’s era of LLMs, the topic has even gradually faded from public view, and from my own attention as well.

Nevertheless, a recent paper, “Transformer-VQ: Linear-Time Transformers via Vector Quantization”, truly impressed me. The author ingeniously observed that by simply performing Vector Quantization (VQ) on the Key in standard Attention, the complexity automatically drops to linear! This linearization approach preserves the form of standard Attention, providing a perfect transition from standard to linear Attention, while maximizing the retention of standard Attention’s capabilities.

Efficient Transformer Challenges

It’s worth mentioning that this site started paying attention to Efficient Transformer work relatively early: the earliest article dates back to 2019, interpreting Sparse Transformer in a blog post titled “Born for Efficiency: From Standard Attention to Sparse Attention”, and several other posts on Efficient Transformers followed.

However, as mentioned at the beginning of this article, despite numerous works in Efficient Transformers and high hopes placed upon them, this field has not seen any truly “breakthrough” achievements. The reasons for this might be:

  1. Many Efficient Transformers achieve speedup at the cost of performance degradation.
  2. The complexity reduction of many Efficient Transformers is only theoretical, with negligible practical speedup.
  3. Some Efficient Transformers are difficult to train for Causal LMs, making them unsuitable in today’s popular LLM landscape.
  4. The emergence of Flash Attention shows that even standard Transformers still have significant room for speed optimization.

VQ It Up

So, why does Transformer-VQ possess “breakthrough” potential?

Simply put, Transformer-VQ “clusters” the Key vectors in Attention and approximates the original vectors with their respective cluster centers, after which Attention’s complexity becomes linear. This means Transformer-VQ only changes the form of the Key; the rest of the components (theoretically) remain entirely unchanged. Thus, it’s a linearization scheme with minimal modifications to Attention, and it clearly shows where the precision is lost due to linearization (i.e., the difference between the cluster center and the original vector).

With that preamble, let’s formally introduce Transformer-VQ. First, assuming $Q,K\in\mathbb{R}^{n\times d_k},V\in\mathbb{R}^{n\times d_v}$, standard Attention is:

$$ softmax\left(QK^{\top}\right)V $$

For simplicity, the scale factor is omitted here. Transformer-VQ changes this to:

$$ softmax\left(Q\hat{K}^{\top}\right)V,\quad \hat{K} = \color{skyblue}{\mathcal{VQ}}(K, C) \quad \text{(1)} $$

where $C\in\mathbb{R}^{c\times d_k}$ is a trainable parameter serving as the VQ codebook. By the way, “VQ” here refers to the VQ in VQ-VAE. Readers unfamiliar with it can refer to “A Brief Introduction to VQ-VAE: Vector Quantized Variational Autoencoders” and “Embarrassingly Simple FSQ: ‘Rounding’ Surpasses VQ-VAE”; it won’t be re-introduced here. In short, the most direct effect of $\color{skyblue}{\mathcal{VQ}}$ is that each vector in $K$ is replaced by its nearest vector in $C$. This means each vector in $\hat{K}$ is one of the vectors in $C$; mathematically, $K\in\mathbb{R}^{n\times d_k}$ becomes $\hat{K}\in C^n$.
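A minimal NumPy sketch of the forward $\color{skyblue}{\mathcal{VQ}}$ step, assuming squared Euclidean distance as the “nearest” criterion (as in VQ-VAE). The function name `vq` and the random data are illustrative assumptions, and only the forward lookup is shown, not the gradient handling discussed later.

```python
import numpy as np

def vq(K, C):
    """Forward VQ: replace every row of K by its nearest row of the codebook C.

    Returns (K_hat, ids) with K_hat = C[ids], so each row of K_hat is a code vector.
    """
    # Pairwise squared distances ||k_i - c_j||^2 = ||k_i||^2 - 2 k_i.c_j + ||c_j||^2, shape (n, c)
    d2 = (K ** 2).sum(axis=1, keepdims=True) - 2.0 * K @ C.T + (C ** 2).sum(axis=1)[None, :]
    ids = d2.argmin(axis=1)  # index of the closest code for each key
    return C[ids], ids

rng = np.random.default_rng(0)
n, c, d_k = 16, 4, 8
K = rng.normal(size=(n, d_k))
C = rng.normal(size=(c, d_k))  # the codebook; trainable in the actual model
K_hat, ids = vq(K, C)
print(K_hat.shape, ids)  # every row of K_hat is one of the c rows of C
```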

Encoder

Of course, if we directly implement Transformer-VQ according to Equation (1), the complexity is still quadratic. However, since each vector in $\hat{K}$ is one of the vectors in $C$, we can first compute $\exp\left(QC^{\top}\right)$ and then “pick out” the results corresponding to $\exp\left(Q\hat{K}{}^{\top}\right)$. Because the size of $C$ is fixed, the complexity of the key operation $QC^{\top}$ is linear. This is the principle behind Transformer-VQ’s linearization (we might call it the “picking out” trick).

As a preliminary step, let’s consider the Encoder case for bidirectional attention. Since

$$ softmax\left(QK^{\top}\right)V = \frac{\exp\left(QK^{\top}\right)V}{\exp\left(QK^{\top}\right)1_{n\times 1}} \quad \text{(2)} $$

Here, $1_{n\times 1}$ denotes an $n\times 1$ matrix of all ones, and the division is row-wise. The denominator is just the numerator with $V$ replaced by $1_{n\times 1}$, so we only need to consider the numerator $\exp\left(QK^{\top}\right)V$. Since each vector in $\hat{K}$ is one of the vectors in $C$, we can construct a one-hot matrix $\Delta\in \{0,1\}^{n\times c}$ whose rows $\Delta_i\in\{0,1\}^c$ are one-hot vectors: if the nonzero entry of $\Delta_i$ is at position $j$, then $\hat{K}_i = C_j$. Hence $\hat{K}=\Delta C$.

Thus, for Transformer-VQ, we have:

$$ \exp\left(Q\hat{K}{}^{\top}\right)V = \exp\left(QC^{\top}\Delta^{\top}\right)V = \exp\left(QC^{\top}\right)\Delta^{\top}V = \exp\left(QC^{\top}\right)(\Delta^{\top}V) $$

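Evidently, the crucial part here is the second equality! Right-multiplying by $\Delta^{\top}$ merely selects columns: column $i$ of $QC^{\top}\Delta^{\top}$ is the column of $QC^{\top}$ belonging to the code assigned to $\hat{K}_i$. Since $\exp$ acts element-wise, it commutes with this column selection, so $\exp\left(QC^{\top}\Delta^{\top}\right)=\exp\left(QC^{\top}\right)\Delta^{\top}$ and the one-hot matrix can be pulled out of the $\exp$. This is the mathematical expression of the “picking out” trick mentioned above. Once separated, by the associativity of matrix multiplication, $\Delta^{\top}$ can first multiply $V$ to obtain a $c\times d_v$ matrix; $\exp\left(QC^{\top}\right)$ is an $n\times c$ matrix, and multiplying it by $\Delta^{\top}V$ yields an $n\times d_v$ matrix. The total theoretical complexity is $\mathcal{O}(ncd_k + ncd_v + ncd_v) = \mathcal{O}(n)$ for fixed $c$, $d_k$, and $d_v$.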

Finally, substituting the result of $\exp\left(Q\hat{K}{}^{\top}\right)V$ into Equation (2), we can compute the complete Attention result (possibly with some details to avoid overflow). The entire process can be completed with linear complexity.
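To check the “picking out” identity numerically, here is a NumPy sketch that computes encoder attention both ways: the quadratic $softmax(Q\hat{K}^{\top})V$, and the linear form $\exp(QC^{\top})(\Delta^{\top}V)$ divided row-wise by $\exp(QC^{\top})(\Delta^{\top}1_{n\times 1})$. The code assignment `ids` is drawn at random here (an assumption for the demo), since the identity holds for any one-hot $\Delta$; the function names are hypothetical. Subtracting a per-row maximum is one common overflow guard, and it cancels in the ratio.

```python
import numpy as np

def attention_quadratic(Q, K_hat, V):
    # Reference: softmax(Q K_hat^T) V via an explicit n x n score matrix
    S = Q @ K_hat.T
    P = np.exp(S - S.max(axis=1, keepdims=True))
    return (P @ V) / P.sum(axis=1, keepdims=True)

def attention_vq_encoder(Q, C, ids, V):
    # Linear form: exp(Q C^T) (Delta^T V); no n x n matrix is ever built
    n, c = Q.shape[0], C.shape[0]
    Delta = np.zeros((n, c))
    Delta[np.arange(n), ids] = 1.0                # one-hot rows, so K_hat = Delta @ C
    S = Q @ C.T                                   # (n, c)
    E = np.exp(S - S.max(axis=1, keepdims=True))  # same shift in numerator and denominator
    num = E @ (Delta.T @ V)                       # (n, c) @ (c, d_v)
    den = E @ Delta.sum(axis=0)                   # Delta^T 1 = how often each code occurs
    return num / den[:, None]

rng = np.random.default_rng(0)
n, c, d_k, d_v = 32, 4, 8, 5
Q = rng.normal(size=(n, d_k))
V = rng.normal(size=(n, d_v))
C = rng.normal(size=(c, d_k))
ids = rng.integers(0, c, size=n)  # stand-in for the VQ assignment of the keys
out_linear = attention_vq_encoder(Q, C, ids, V)
out_reference = attention_quadratic(Q, C[ids], V)
print(np.abs(out_linear - out_reference).max())  # agreement up to rounding error
```

Only $n\times c$ and $c\times d_v$ intermediates appear in the linear version, which is where the linear cost in $n$ comes from.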

Decoder

Now let’s consider the Decoder for unidirectional attention, which is crucial for training generative models and is the foundation of current LLMs. With the Encoder as a groundwork, understanding the Decoder is not so difficult. Assuming $Q_i, \hat{K}_j \in \mathbb{R}^{1\times d_k}, V_j\in\mathbb{R}^{1\times d_v}$ are row vectors from the sequences $Q,\hat{K},V$, then for the numerator of the Decoder, we have:

$$ \begin{aligned}O_i =&\, \sum_{j\leq i}\exp\left(Q_i\hat{K}{}_j^{\top}\right)V_j = \sum_{j\leq i}\exp\left(Q_i C^{\top}\Delta_j^{\top}\right)V_j \\=&\, \sum_{j\leq i}\exp\left(Q_i C^{\top}\right)\Delta_j^{\top}V_j = \exp\left(Q_i C^{\top}\right)\sum_{j\leq i}\Delta_j^{\top}V_j\end{aligned} $$

If $c\times d_v$ is not too large, the last expression can be computed directly using a $\text{cumsum}$ operator. However, in general, especially with Multi-Head Attention, to save VRAM, it’s usually converted into an RNN for recursive computation, similar to the “Autoregressive Generation” section in “Exploring Linear Attention: Must Attention have a Softmax?”. That is, let $U_i = \sum_{j\leq i}\Delta_j^{\top}V_j\in\mathbb{R}^{c\times d_v}$, then:

$$ O_i = \exp\left(Q_i C^{\top}\right)U_i,\quad U_i = U_{i-1} + \Delta_i^{\top}V_i \quad \text{(3)} $$
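A step-by-step NumPy sketch of the recursion in Equation (3), run alongside an analogous counter for the denominator (the number of times each code has appeared so far, i.e. $\sum_{j\leq i}\Delta_j^{\top}$) and checked against masked quadratic attention. The function names and the random code assignment are assumptions for illustration; note that the state stays $c\times d_v$ no matter how long the sequence is.

```python
import numpy as np

def causal_attention_quadratic(Q, K_hat, V):
    # Reference: softmax(Q K_hat^T + M) V with a causal mask over an n x n matrix
    n = Q.shape[0]
    S = np.where(np.tril(np.ones((n, n), dtype=bool)), Q @ K_hat.T, -np.inf)
    P = np.exp(S - S.max(axis=1, keepdims=True))
    return (P @ V) / P.sum(axis=1, keepdims=True)

def causal_attention_vq_rnn(Q, C, ids, V):
    # Eq. (3): O_i = exp(Q_i C^T) U_i,  U_i = U_{i-1} + Delta_i^T V_i
    n, c, d_v = Q.shape[0], C.shape[0], V.shape[1]
    U = np.zeros((c, d_v))  # numerator state
    z = np.zeros(c)         # denominator state: occurrences of each code so far
    out = np.empty((n, d_v))
    for i in range(n):
        U[ids[i]] += V[i]   # Delta_i^T V_i only touches row ids[i]
        z[ids[i]] += 1.0
        s = Q[i] @ C.T      # scores against the c codes, shape (c,)
        e = np.exp(s - s.max())
        out[i] = (e @ U) / (e @ z)
    return out

rng = np.random.default_rng(1)
n, c, d_k, d_v = 24, 4, 8, 5
Q, V = rng.normal(size=(n, d_k)), rng.normal(size=(n, d_v))
C = rng.normal(size=(c, d_k))
ids = rng.integers(0, c, size=n)
out_rnn = causal_attention_vq_rnn(Q, C, ids, V)
out_ref = causal_attention_quadratic(Q, C[ids], V)
```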

During inference, this step-by-step recursive calculation is fine. However, training step-by-step can be slow. We can switch to block-by-block calculation for acceleration: without loss of generality, let $n=lm$, where $l$ is the block_size and $m$ is the number of blocks. A block slice $[il:(i+1)l]$ is abbreviated as $[i]$. Then:

$$ \begin{aligned}O_{[i]} =&\, \exp\left(Q_{[i]}\hat{K}{}_{[i]}^{\top} + M\right)V_{[i]} + \sum_{j\lt i}\exp\left(Q_{[i]}\hat{K}{}_{[j]}^{\top}\right)V_{[j]} \\=&\, \exp\left(Q_{[i]}\hat{K}{}_{[i]}^{\top} + M\right)V_{[i]} + \sum_{j\lt i}\exp\left(Q_{[i]}C^{\top}\Delta_{[j]}^{\top}\right)V_{[j]} \\=&\, \exp\left(Q_{[i]}\hat{K}{}_{[i]}^{\top} + M\right)V_{[i]} + \exp\left(Q_{[i]}C^{\top}\right)\sum_{j\lt i}\Delta_{[j]}^{\top}V_{[j]}\end{aligned} $$

where $M\in\{-\infty,0\}^{l\times l}$ is a lower-triangular Attention Mask, meaning $M_{i,j}=0$ when $i \geq j$ and $M_{i,j}=-\infty$ otherwise. Therefore, letting $U_i = \sum_{j\leq i}\Delta_{[j]}^{\top}V_{[j]}$ (so the sum over $j<i$ above is $U_{i-1}$, with $U_{-1}=0$), we have:

$$ \begin{aligned}O_{[i]} =&\, \exp\left(Q_{[i]}\hat{K}{}_{[i]}^{\top} + M\right)V_{[i]} + \exp\left(Q_{[i]}C^{\top}\right)U_{i-1}\\[5pt]U_i =&\, U_{i-1} + \Delta_{[i]}^{\top}V_{[i]}\end{aligned} $$

This way, we reduce the number of recursion steps to $m$, enabling full utilization of hardware parallelism while maintaining linear efficiency. The denominator can be calculated in the same way, and finally, dividing them yields the complete Attention result.
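The block-wise variant can be sketched the same way, assuming $n$ is a multiple of the block size $l$: each block combines a masked intra-block term $\exp(Q_{[i]}\hat{K}_{[i]}^{\top}+M)V_{[i]}$ with $\exp(Q_{[i]}C^{\top})$ applied to the state accumulated from all earlier blocks, and only afterwards folds its own $\Delta_{[i]}^{\top}V_{[i]}$ into that state. This NumPy version (hypothetical names, random code assignment) takes $m=n/l$ sequential steps and is compared against the quadratic reference.

```python
import numpy as np

def causal_attention_quadratic(Q, K_hat, V):
    # Reference: softmax(Q K_hat^T + M) V over the full n x n matrix
    n = Q.shape[0]
    S = np.where(np.tril(np.ones((n, n), dtype=bool)), Q @ K_hat.T, -np.inf)
    P = np.exp(S - S.max(axis=1, keepdims=True))
    return (P @ V) / P.sum(axis=1, keepdims=True)

def causal_attention_vq_blocks(Q, C, ids, V, l):
    n, c, d_v = Q.shape[0], C.shape[0], V.shape[1]
    assert n % l == 0, "this sketch assumes n = l * m"
    K_hat = C[ids]
    mask = np.tril(np.ones((l, l), dtype=bool))  # M: keep j <= i inside a block
    U = np.zeros((c, d_v))  # sum of Delta_[j]^T V_[j] over the blocks already processed
    z = np.zeros(c)         # matching denominator state (code counts)
    out = np.empty((n, d_v))
    for b in range(n // l):
        blk = slice(b * l, (b + 1) * l)
        s_loc = np.where(mask, Q[blk] @ K_hat[blk].T, -np.inf)  # (l, l) intra-block scores
        s_his = Q[blk] @ C.T                                    # (l, c) scores via the codebook
        shift = np.maximum(s_loc.max(axis=1), s_his.max(axis=1))[:, None]
        e_loc, e_his = np.exp(s_loc - shift), np.exp(s_his - shift)
        num = e_loc @ V[blk] + e_his @ U
        den = e_loc.sum(axis=1) + e_his @ z
        out[blk] = num / den[:, None]
        np.add.at(U, ids[blk], V[blk])  # U <- U + Delta_[b]^T V_[b]
        np.add.at(z, ids[blk], 1.0)
    return out

rng = np.random.default_rng(2)
n, l, c, d_k, d_v = 24, 6, 4, 8, 5
Q, V = rng.normal(size=(n, d_k)), rng.normal(size=(n, d_v))
C = rng.normal(size=(c, d_k))
ids = rng.integers(0, c, size=n)
out_blocks = causal_attention_vq_blocks(Q, C, ids, V, l)
out_ref = causal_attention_quadratic(Q, C[ids], V)
```

`np.add.at` is used because it accumulates correctly when the same code index appears several times within a block.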

Local Enhancement

Is that all? Not quite. If it were just this, Transformer-VQ might not be very different from previous kernelized attention models like Performer, which are based on matrix decomposition. When the sequence length $n$ is much larger than the codebook size $c$, by the pigeonhole principle, some code vectors will inevitably appear repeatedly. It can even be reasonably guessed that all code vectors should be uniformly distributed throughout the sequence. This implies that the attention for neighboring tokens would be the same as for certain distant tokens, meaning the model cannot distinguish proximity. This is essentially the low-rank problem inherent in all kernelized attention mechanisms.
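A small NumPy check of this point, with random codes standing in for VQ’d Keys (an assumption): the unnormalized attention matrix $\exp(Q\hat{K}^{\top})=\exp(QC^{\top})\Delta^{\top}$ has rank at most $c$, and any two positions that share a code receive identical scores from every query, however near or far they are.

```python
import numpy as np

rng = np.random.default_rng(0)
n, c, d_k = 64, 8, 16
Q = rng.normal(size=(n, d_k))
C = rng.normal(size=(c, d_k)) / 4          # small codes keep exp() well-scaled
ids = rng.integers(0, c, size=n)           # code assigned to each position
A = np.exp(Q @ C[ids].T)                   # n x n unnormalised attention scores

s = np.linalg.svd(A, compute_uv=False)     # singular values, in descending order
numerical_rank = int((s > 1e-8 * s[0]).sum())
print(numerical_rank, "<=", c)

# Columns of positions that share position 0's code are identical, near or far
same_code = np.flatnonzero(ids == ids[0])
print(same_code, np.allclose(A[:, same_code], A[:, [0]]))
```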

Existing experience tells us that for language models, neighboring tokens are often more important than distant ones, so a good language model architecture should be able to distinguish proximity. To this end, Transformer-VQ adds a sliding-window-shaped Attention Bias (denoted $B$) to $Q\hat{K}^{\top}$ to weight neighboring tokens, as shown in the figure below:

Diagram of Window Attention Bias

From the figure, it can be seen that if the Window size is set directly to the block size $l$, i.e., $B_{i,j}=0$ whenever $i < j$ or $i - j > l$ (so the non-zero entries of $B$ are confined to $0\leq i-j\leq l$), then in the block-by-block calculation $B$ affects at most the two nearest blocks, while more distant blocks can still be linearized with the “picking out” trick. For convenience in the following derivation, we denote $B_{[i,j]} = B_{[il:(i+1)l,jl:(j+1)l]}$, and to keep the formulas short, the diagonal block $B_{[i,i]}$ below also absorbs the causal mask $M$. Then:

$$ \begin{aligned}O_{[i]} =&\, \exp\left(Q_{[i]}\hat{K}{}_{[i]}^{\top} + B_{[i,i]}\right)V_{[i]} + \exp\left(Q_{[i]}\hat{K}{}_{[i-1]}^{\top} + B_{[i,i-1]}\right)V_{[i-1]} + \sum_{j\lt i-1}\exp\left(Q_{[i]}\hat{K}{}_{[j]}^{\top}\right)V_{[j]} \\=&\, \exp\left(Q_{[i]}\hat{K}{}_{[i]}^{\top} + B_{[i,i]}\right)V_{[i]} + \exp\left(Q_{[i]}\hat{K}{}_{[i-1]}^{\top} + B_{[i,i-1]}\right)V_{[i-1]} + \sum_{j\lt i-1}\exp\left(Q_{[i]}C^{\top}\Delta_{[j]}^{\top}\right)V_{[j]} \\=&\, \exp\left(Q_{[i]}\hat{K}{}_{[i]}^{\top} + B_{[i,i]}\right)V_{[i]} + \exp\left(Q_{[i]}\hat{K}{}_{[i-1]}^{\top} + B_{[i,i-1]}\right)V_{[i-1]} + \exp\left(Q_{[i]}C^{\top}\right)\sum_{j\lt i-1}\Delta_{[j]}^{\top}V_{[j]}\end{aligned} $$

Therefore, defining $U_i = \sum_{j\leq i}\Delta_{[j]}^{\top}V_{[j]}$ and adopting the convention that $V_{[-1]}$, $U_{-1}$, and $U_{-2}$ are all-zero matrices, we obtain:

$$ \begin{aligned}O_{[i]} =&\, \exp\left(Q_{[i]}\hat{K}{}_{[i]}^{\top} + B_{[i,i]}\right)V_{[i]} + \exp\left(Q_{[i]}\hat{K}{}_{[i-1]}^{\top} + B_{[i,i-1]}\right)V_{[i-1]} + \exp\left(Q_{[i]}C^{\top}\right)U_{i-2}\\[5pt]U_i =&\, U_{i-1} + \Delta_{[i]}^{\top}V_{[i]}\end{aligned} \quad \text{(4)} $$

I believe that the introduction of $B$ is key to Transformer-VQ distinguishing itself from other kernelized attention mechanisms. To reduce the number of parameters and support variable-length generation, we constrain the non-zero part of $B$ to be a “Toeplitz matrix,” meaning $B_{i,j}$ is a function of $i-j$. In this case, $B$ acts like an additive relative position encoding. In addition to this approach, one could also consider using my previously proposed ReRoPE, which is a windowed version of rotary position encoding, and has the same relative position encoding shape as $B$.
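Equation (4) can likewise be sketched in NumPy. The Toeplitz parameterization below, $B_{i,j}=w_{i-j}$ for $0\leq i-j\leq l$ and $0$ otherwise, with a learnable vector $w$ of length $l+1$, is an assumption chosen to match the description above (window size equal to block size); it is not necessarily the paper’s exact form. The causal mask is folded into the diagonal block, and blocks more than one step back enter only through the state $U_{i-2}$. All function names are hypothetical.

```python
import numpy as np

def window_bias(pos_i, pos_j, w):
    # Hypothetical Toeplitz bias: B_{i,j} = w[i - j] if 0 <= i - j <= l, else 0 (l = len(w) - 1)
    l = len(w) - 1
    rel = pos_i[:, None] - pos_j[None, :]
    return np.where((rel >= 0) & (rel <= l), w[np.clip(rel, 0, l)], 0.0)

def causal_attention_bias_quadratic(Q, K_hat, V, w):
    # Reference: softmax(Q K_hat^T + B + M) V over the full n x n matrix
    n = Q.shape[0]
    pos = np.arange(n)
    S = Q @ K_hat.T + window_bias(pos, pos, w)
    S = np.where(np.tril(np.ones((n, n), dtype=bool)), S, -np.inf)
    P = np.exp(S - S.max(axis=1, keepdims=True))
    return (P @ V) / P.sum(axis=1, keepdims=True)

def causal_attention_vq_window(Q, C, ids, V, w):
    l = len(w) - 1  # window size = block size
    n, c, d_v = Q.shape[0], C.shape[0], V.shape[1]
    assert n % l == 0, "this sketch assumes n = l * m"
    K_hat, pos = C[ids], np.arange(n)
    mask = np.tril(np.ones((l, l), dtype=bool))
    U = np.zeros((c, d_v))  # holds U_{b-2}: blocks that have left the window
    z = np.zeros(c)         # matching denominator state (code counts)
    out = np.empty((n, d_v))
    for b in range(n // l):
        cur = slice(b * l, (b + 1) * l)
        s_cur = Q[cur] @ K_hat[cur].T + window_bias(pos[cur], pos[cur], w)
        s_cur = np.where(mask, s_cur, -np.inf)  # causal mask folded into B_[b,b]
        s_his = Q[cur] @ C.T
        if b > 0:
            prv = slice((b - 1) * l, b * l)
            s_prv = Q[cur] @ K_hat[prv].T + window_bias(pos[cur], pos[prv], w)
            V_prv = V[prv]
        else:  # convention V_[-1] = 0: there is no previous block
            s_prv, V_prv = np.full((l, l), -np.inf), np.zeros((l, d_v))
        shift = np.max([s_cur.max(axis=1), s_prv.max(axis=1), s_his.max(axis=1)], axis=0)[:, None]
        e_cur, e_prv, e_his = (np.exp(s - shift) for s in (s_cur, s_prv, s_his))
        num = e_cur @ V[cur] + e_prv @ V_prv + e_his @ U
        den = e_cur.sum(axis=1) + e_prv.sum(axis=1) + e_his @ z
        out[cur] = num / den[:, None]
        if b > 0:  # block b-1 leaves the window, so U becomes U_{b-1} for the next step
            np.add.at(U, ids[prv], V[prv])
            np.add.at(z, ids[prv], 1.0)
    return out

rng = np.random.default_rng(3)
n, l, c, d_k, d_v = 24, 4, 4, 8, 5
Q, V = rng.normal(size=(n, d_k)), rng.normal(size=(n, d_v))
C = rng.normal(size=(c, d_k))
ids = rng.integers(0, c, size=n)
w = rng.normal(size=l + 1)  # hypothetical learned bias values for offsets 0..l
out_blocks = causal_attention_vq_window(Q, C, ids, V, w)
out_ref = causal_attention_bias_quadratic(Q, C[ids], V, w)
```

Because every key more than $l$ positions back has zero bias, its score depends only on its code, which is exactly what lets the third term use the codebook.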

Gradient Backpropagation

Wait, it seems we forgot something. Readers familiar with VQ-VAE know that “each vector in $\hat{K}$ is one of the vectors in $C$” is only for forward propagation; for backward propagation, the original $K$ is used. This means that even if $\hat{K}_j$ at different positions equals the same $C_k$, their gradients are not equal. This is called STE (Straight-Through Estimator). Due to the existence of STE, the “picking out” trick can theoretically only be used during inference, and cannot linearize during training.

Are there no other options? Indeed, if we insist on exact gradients, there is no linearly efficient solution. However, since VQ’s gradients are themselves approximate, it seems unnecessary for Attention to obtain exact gradients either. So the author devised a compromise: still compute recursively according to Equation (4), but apply STE only in the first two terms (so the Keys of the current and previous blocks receive gradients), while the gradient through the third term, the one involving $U_{i-2}$, is stopped directly (using the stop_gradient operator). This keeps the model linear while retaining the most important gradients (those from the two nearest blocks), which is a fairly reasonable approximation. From this perspective, Transformer-VQ is quite similar to Transformer-XL, which also stops gradients for the historical window during recursion: the history participates in the computation but does not pass gradients back.

After solving the gradient backpropagation problem, combining the autoregressive cross-entropy loss with the auxiliary loss from VQ used to update the codebook yields the complete training objective. Of course, for codebook updates, Transformer-VQ adopts a direct moving average scheme, so only an auxiliary loss for Key is added. Readers familiar with VQ-VAE will understand these details after a quick look at the original paper.
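For readers who have not seen a moving-average codebook update, here is a NumPy sketch in the style commonly used for VQ-VAE: each code is kept as the running average of the Keys assigned to it, and the Keys get an auxiliary “commitment” term $\Vert K - \text{sg}(\hat{K})\Vert^2$ pulling them toward their codes. The decay value, the small-count guard, and all helper names are assumptions; the exact form Transformer-VQ uses should be taken from the paper and repository. NumPy has no autodiff, so the stop-gradient is only noted in a comment.

```python
import numpy as np

def nearest_code(K, C):
    # Index of the nearest code (squared Euclidean distance) for each key
    d2 = (K ** 2).sum(axis=1, keepdims=True) - 2.0 * K @ C.T + (C ** 2).sum(axis=1)[None, :]
    return d2.argmin(axis=1)

def ema_codebook_update(C_sum, C_count, K, ids, decay=0.99):
    # Running (exponentially decayed) sums and counts of the Keys assigned to each code
    c = C_sum.shape[0]
    batch_count = np.bincount(ids, minlength=c).astype(float)
    batch_sum = np.zeros_like(C_sum)
    np.add.at(batch_sum, ids, K)
    C_count = decay * C_count + (1.0 - decay) * batch_count
    C_sum = decay * C_sum + (1.0 - decay) * batch_sum
    C = C_sum / np.maximum(C_count, 1e-5)[:, None]  # guard against never-used codes
    return C_sum, C_count, C

def key_commitment_loss(K, K_hat):
    # ||K - sg(K_hat)||^2 per key, averaged; in an autodiff framework K_hat would be
    # wrapped in stop_gradient so that only the Keys are pulled toward their codes
    return float(np.mean(np.sum((K - K_hat) ** 2, axis=1)))

rng = np.random.default_rng(0)
c, d_k = 4, 3
C = rng.normal(size=(c, d_k))
C_sum, C_count = C.copy(), np.ones(c)  # start the running averages at the initial codebook
for _ in range(200):
    K = rng.normal(size=(64, d_k))      # stand-in for one batch of Keys
    ids = nearest_code(K, C)
    loss = key_commitment_loss(K, C[ids])
    C_sum, C_count, C = ema_codebook_update(C_sum, C_count, K, ids)
```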

Experimental Results

In this section, we examine the experimental results from the original paper. The author has open-sourced the code here:

Github: https://github.com/transformer-vq/transformer_vq

It is worth noting that the base architecture of the author’s VQ implementation is not the conventional MHA (Multi-Head Attention), but rather GAU (Gated Attention Unit) + Softmax, which I have always strongly endorsed. A more accurate name for Transformer-VQ would be “GAU-VQ.” Readers unfamiliar with GAU can refer to “FLASH: Perhaps the Most Interesting Efficient Transformer Design Recently” and “It is said that Attention and Softmax go better together~”. Simply put, GAU itself is more efficient than MHA, and combined with the VQ trick, it becomes even more powerful.

In terms of experiments, the author evaluated language modeling (ENWIK8, PG-19) and image generation (IMAGENET64). In all experiments, the codebook size was $c=512$. The largest model had 1.3B parameters, which, while not comparable to mainstream large models, is not small for research purposes. The overall experimental results are excellent:

PG-19 Experimental Results
IMAGENET64 Experimental Results

Finally, somewhat surprisingly, Transformer-VQ has only one author, whose listed affiliation is “Independent Researcher.”

Further Thoughts

I found that Transformer-VQ connects to many research topics, which is one of the reasons I appreciate it so much.

First, I reiterate my praise for the author’s astonishing insight. The discovery that “simply VQing the Key makes Transformer complexity linear” is truly wonderful. It achieves a natural transition from standard Attention to linear Attention, and by adding an Attention Bias, it becomes more effective than many kernelized Attention mechanisms. Furthermore, the “clustering” approach via VQ is more sophisticated than methods like Linformer and Nyströmformer, because it prevents future information leakage and can naturally be used for Causal language models.

We know that VQ essentially converts sequences into discrete IDs, which is very similar to the role of a Tokenizer. From this perspective, Transformer-VQ, like models such as MegaByte, incorporates the Tokenizer directly into the model. Compared to MegaByte, the VQ operation is closer to our traditional, intuitive understanding of a Tokenizer. Therefore, Transformer-VQ is actually very well suited to training “No Tokenizer” models that take raw Bytes as input. In fact, the ENWIK8 experiment mentioned above used Bytes as input, and Transformer-VQ significantly outperformed MegaByte there.

Compared to the recently released RetNet, Transformer-VQ has no explicit long-range decay, so its long-context capability might be better. At the same time, since the Keys have undergone VQ and all belong to a finite set, there are no unseen Keys, so its length extrapolation capability is also likely to be better. Although Transformer-VQ’s underlying architecture, GAU, is single-head, its memory state during recursion is $U_i\in\mathbb{R}^{c\times d_v}$, which in default settings is larger than that of multi-head RetNet (with $h$ heads, RetNet’s memory state size is $hd_k^2$, and in default settings $d_v = 2hd_k$). Therefore, the memory capacity is theoretically sufficient.

Since the previous article was about “Embarrassingly Simple FSQ: ‘Rounding’ Surpasses VQ-VAE”, some readers might wonder whether the simpler FSQ could replace VQ here. I believe it would be difficult, for reasons actually given in that article. First, $c=512$ is still within the codebook-size range where VQ outperforms FSQ, so switching to FSQ would likely degrade performance. Second, since the Key of every Attention layer is VQ’d, the “encoder” and “decoder” around each VQ step (the layers before and after it) are, on average, not very strong; in this situation VQ offers higher approximation accuracy, whereas FSQ is better suited to scenarios where both the encoder and decoder are sufficiently strong. Third, Transformer-VQ needs the center vectors after VQ, not just the IDs, for the Keys, while FSQ directly yields IDs, making it harder to recover approximate center vectors.

In addition, using VQ instead of FSQ makes it possible for Transformer-VQ to be fine-tuned from existing pre-trained models like LLAMA2, rather than just trained from scratch. Because VQ has distinct geometric significance and many commonalities with K-Means, we can start from an existing pre-trained model, select some samples to compute the Keys, perform K-Means on the Keys to obtain center vectors as the codebook initialization, and then add VQ to the original model for fine-tuning. However, Transformer-VQ is not well-suited for RoPE, so as mentioned earlier, RoPE models should be switched to ReRoPE before VQ, in which case adding a bias might not be necessary.
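The K-Means initialization described above might look like the following NumPy sketch (plain Lloyd iterations). Collecting the Keys from a pre-trained layer is model-specific and is not shown; `sample_keys` below is a synthetic stand-in, and the helper names and iteration count are assumptions.

```python
import numpy as np

def kmeans_codebook(keys, c, iters=25, seed=0):
    """Plain Lloyd's k-means over sampled Keys; returns a (c, d_k) initial codebook."""
    rng = np.random.default_rng(seed)
    C = keys[rng.choice(len(keys), size=c, replace=False)].copy()  # start from random Keys
    for _ in range(iters):
        d2 = (keys ** 2).sum(axis=1, keepdims=True) - 2.0 * keys @ C.T + (C ** 2).sum(axis=1)[None, :]
        ids = d2.argmin(axis=1)                  # assignment step
        for k in range(c):                       # update step
            members = keys[ids == k]
            if len(members) > 0:                 # an empty cluster keeps its old center
                C[k] = members.mean(axis=0)
    return C

def inertia(keys, C):
    # Sum of squared distances from each Key to its nearest code (the k-means objective)
    d2 = ((keys[:, None, :] - C[None, :, :]) ** 2).sum(axis=-1)
    return float(d2.min(axis=1).sum())

rng = np.random.default_rng(0)
sample_keys = rng.normal(size=(2000, 16))  # stand-in for Keys gathered from a pre-trained layer
codebook = kmeans_codebook(sample_keys, c=32)
print(codebook.shape)  # (32, 16)
```

Since each Lloyd step can only lower (or keep) the `inertia` objective, the result is never worse than the random starting codebook, which makes it a reasonable initialization before VQ fine-tuning.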

In summary, in my view, Transformer-VQ stands out among many Efficient Transformer works as one of the most unique, excellent, and deeply promising solutions.

Summary

This article introduced an Efficient Transformer solution called Transformer-VQ. It is based on the observation that “simply VQing the Key makes Transformer complexity linear.” I personally believe this is a very unique and brilliant linearization approach, and its experimental results are excellent. It can be understood as a more sophisticated linear Attention/RNN model, or as an Attention model with a “trainable Tokenizer.”

@online{kexuefm-9844,
        title={VQing the Key Makes Transformer Complexity Linear},
        author={苏剑林},
        year={2023},
        month={11},
        url={\url{https://kexue.fm/archives/9844}},
}