This is a gemini-2.5-flash translation of a Chinese article.
It has NOT been vetted for errors. You should have the original article open in a parallel tab at all times.
I first came across VQ-VAE a long time ago, but it did not interest me much at the time. Recently, two things rekindled my interest. First, VQ-VAE-2 achieved generation quality matching BigGAN (as reported by Synced). Second, while reading the NLP paper "Unsupervised Paraphrasing without Translation", I noticed that it also uses VQ-VAE. Together these suggest that VQ-VAE is a versatile and interesting model, so I decided to read it carefully.
Model Overview
VQ-VAE (Vector Quantised-Variational AutoEncoder) first appeared in the paper "Neural Discrete Representation Learning". Like VQ-VAE-2, it is a major work from Google's team.
Interesting but Esoteric
As an autoencoder, VQ-VAE's most distinctive feature is that its codes are discrete: every element of the resulting code is an integer. This is what "Quantised" means, and we can call it "quantization" (the same root as "quantum" in quantum mechanics; both imply discreteness).
Although its codes are discrete, the model as a whole can still be trained with ordinary backpropagation, and its reconstructions look quite sharp (as shown in the image at the beginning of the article). This suggests that VQ-VAE contains some interesting and valuable techniques worth learning. However, after reading the original paper, I found it somewhat hard to follow. The difficulty is not the genuine depth of, say, the original ON-LSTM paper; rather, the paper feels deliberately obscure.
First, once you have read the whole paper, you realize that VQ-VAE is really an AE (AutoEncoder) rather than a VAE (Variational AutoEncoder). I do not know why the authors used probabilistic language to tie it to VAEs; it clearly makes the paper harder to understand. Second, one of VQ-VAE's core steps is the Straight-Through Estimator, a technique for passing gradients through the discretization of latent variables, and the paper does not explain it in enough detail; you need to read the source code to understand what it means. Finally, the paper does not articulate its core idea well: it reads as a description of the model itself rather than of the thinking behind it.
PixelCNN
To trace the idea behind VQ-VAE, we have to start with autoregressive models. VQ-VAE's approach to generative modeling grew out of autoregressive models such as PixelRNN and PixelCNN. These models observe that the images we want to generate are actually discrete rather than continuous. Take a cifar10 image: it is a $32\times 32$ 3-channel image, i.e. a $32\times 32\times 3$ matrix whose elements are integers from 0 to 255. We can therefore view it as a sentence of length $32\times 32\times 3=3072$ over a vocabulary of size $256$, and generate the image pixel by pixel like a language model, feeding in all preceding pixels to predict the next one. For an $n\times n\times 3$ image, this is the so-called autoregressive method:

$$ p(x_1,x_2,\dots,x_{3n^2}) = p(x_1)p(x_2|x_1)\cdots p(x_{3n^2}|x_1,x_2,\dots,x_{3n^2-1}) $$

where $p(x_1),p(x_2|x_1),\dots,p(x_{3n^2}|x_1,x_2,\dots,x_{3n^2-1})$ are all 256-class classification problems, differing only in what they condition on.
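To make the sequential nature of this factorization concrete, here is a minimal pure-Python sketch of chain-rule sampling. `toy_conditional` is a hypothetical stand-in for a trained PixelCNN (it simply favours values near the previous pixel); the only point is that drawing $3n^2$ values takes $3n^2$ sequential calls to the model.

```python
import random


def toy_conditional(prefix):
    """Hypothetical stand-in for PixelCNN: returns a 256-way distribution
    p(x_t | x_1, ..., x_{t-1}) that favours values near the previous pixel."""
    prev = prefix[-1] if prefix else 128
    weights = [1.0 / (1 + abs(v - prev)) for v in range(256)]
    total = sum(weights)
    return [w / total for w in weights]


def sample_autoregressive(length, conditional, seed=0):
    """Draw x_1, ..., x_length one at a time, following the chain rule."""
    rng = random.Random(seed)
    pixels, calls = [], 0
    for _ in range(length):
        probs = conditional(pixels)  # one model evaluation per pixel
        calls += 1
        pixels.append(rng.choices(range(256), weights=probs)[0])
    return pixels, calls


n = 2
pixels, calls = sample_autoregressive(3 * n * n, toy_conditional)
print(len(pixels), calls)  # 12 12
```

The loop cannot be parallelized across pixels, because each conditional needs every previously sampled value.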
PixelRNN and PixelCNN are covered in many online resources, so I won't go into detail here. I suspect one could also ride the BERT wave and try "PixelAtt" (attention) for this task. Research on autoregressive models mainly focuses on two questions. The first is how to choose the generation order: image sequences are not one-dimensional but at least two-dimensional, and often three-dimensional, so generating "left to right, then top to bottom", "top to bottom, then left to right", "center first, then outwards", or in some other order significantly affects generation quality. The second is how to speed up sampling. Among the work I have read, a relatively recent result on autoregressive models is the ICLR 2019 paper "Generating High Fidelity Images with Subscale Pixel Networks and Multidimensional Upscaling".
The autoregressive method gives robust and effective probability estimates, but it has one fatal drawback: it is slow. Because it generates pixel by pixel, every pixel requires its own sampling step. The cifar10 example above is already a small image; for image generation to be convincing today, at least $128\times 128\times 3$ is needed. That is nearly 50,000 pixels in total (imagine generating a sentence of length 50,000), which is extremely time-consuming to generate pixel by pixel. Moreover, for such long sequences, neither RNNs nor CNNs can adequately capture the long-range dependencies.
Another issue with plain autoregressive models is that they sever the relationships between categories. Treating each pixel as a 256-class classification problem is legitimate because pixel values are discrete, but neighboring values are very close in practice, and a pure classification objective cannot capture this. More mathematically, the objective function, cross-entropy, is $-\log p_t$, where $p_t$ is the probability the model assigns to the target value. If the target pixel is 100 and the model puts its probability mass on 99, then, because 99 is a different class, $p_t$ will be close to 0 and $-\log p_t$ will be very large, giving a huge loss. Visually, however, pixel values 100 and 99 are nearly indistinguishable, and such a large loss should not occur.
VQ-VAE
To address these inherent shortcomings of autoregressive models, VQ-VAE proposes: first reduce dimensionality, then model the resulting codes with PixelCNN.
Dimensionality Reduction and Discretization
This approach might seem natural and unremarkable, but in reality, it’s anything but natural.
Because PixelCNN generates discrete sequences, modeling the codes with PixelCNN requires the codes themselves to be discrete. However, common dimensionality reduction techniques, such as autoencoders, produce continuous latent variables, not discrete ones. At the same time, producing discrete variables usually brings the problem of vanishing gradients. Furthermore, how can we make sure that images reconstructed from the reduced codes are not distorted? If the distortion is severe, even worse than an ordinary VAE, then VQ-VAE would have little value.
Fortunately, VQ-VAE does provide an effective training strategy to solve these two problems.
Nearest Neighbor Reconstruction
In VQ-VAE, an $n\times n\times 3$ image $x$ is first passed into an $encoder$, yielding a continuous encoding vector $z$:

$$ z = encoder(x) $$
Here, $z$ is a vector of size $d$. Additionally, VQ-VAE maintains an Embedding layer, also known as a codebook, denoted as

$$ E = [e_1, e_2, \dots, e_K] $$

where each $e_i$ is also a vector of size $d$. VQ-VAE then maps $z$ to one of these $K$ vectors by nearest-neighbor search:

$$ z \to e_k,\quad k = \mathop{\text{argmin}}_j \Vert z - e_j\Vert_2 $$
We denote the codebook vector assigned to $z$ as $z_q$ and regard $z_q$ as the final code. Finally, $z_q$ is passed into a $decoder$, which tries to reconstruct the original image: $\hat{x}=decoder(z_q)$.
The entire process is:

$$ x\xrightarrow{encoder} z \xrightarrow{\text{nearest neighbor}} z_q \xrightarrow{decoder}\hat{x} $$

Because $z_q$ is one of the vectors in the codebook $E$, it is effectively equivalent to one of the $K$ integers $1,2,\dots,K$. The whole process therefore amounts to encoding the entire image into a single integer.
Of course, the process above is simplified. If the image were encoded into a single vector, reconstruction would inevitably be distorted and generalization would be hard to guarantee. In practice, therefore, several convolutional layers encode $x$ into $m\times m$ vectors of size $d$:

$$ z = encoder(x) = \begin{pmatrix} z_{11} & z_{12} & \cdots & z_{1m}\\ z_{21} & z_{22} & \cdots & z_{2m}\\ \vdots & \vdots & \ddots & \vdots\\ z_{m1} & z_{m2} & \cdots & z_{mm} \end{pmatrix} $$

That is, $z$ has a total size of $m\times m\times d$ and still preserves spatial structure. Each of these vectors is then mapped to a codebook vector by the method above, giving a $z_q$ of the same size, which is used for reconstruction. In this way, $z_q$ is also equivalent to an $m\times m$ integer matrix, which achieves discrete encoding.
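Applied position by position, the same lookup turns an $m\times m\times d$ encoder output into an $m\times m$ integer matrix. A pure-Python sketch with a hypothetical $2\times 2$ grid and toy codebook (nested lists stand in for tensors; 0-based indices):

```python
def quantize_grid(z_grid, codebook):
    """Map an m x m grid of d-dim vectors to an m x m integer matrix
    and the matching m x m x d grid of codebook vectors z_q."""
    def sq_dist(a, b):
        return sum((ai - bi) ** 2 for ai, bi in zip(a, b))
    indices = [[min(range(len(codebook)), key=lambda j: sq_dist(z, codebook[j]))
                for z in row] for row in z_grid]
    z_q = [[codebook[k] for k in row] for row in indices]
    return indices, z_q


E = [[0.0, 0.0], [1.0, 1.0], [-1.0, 2.0]]  # toy codebook: K = 3, d = 2
z = [[[0.1, -0.2], [0.9, 0.8]],
     [[-0.7, 1.9], [1.2, 1.1]]]            # hypothetical encoder output, m = 2
indices, z_q = quantize_grid(z, E)
print(indices)  # [[0, 1], [2, 1]]
```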
Customizing Gradients
We know that a standard autoencoder can be trained directly with the following loss:

$$ \Vert x - decoder(z)\Vert_2^2 $$

In VQ-VAE, however, reconstruction uses $z_q$ rather than $z$, so it seems this loss should be used instead:

$$ \Vert x - decoder(z_q)\Vert_2^2 $$
The problem is that constructing $z_q$ involves $\text{argmin}$, an operation with no useful gradient (its gradient is zero almost everywhere). Therefore, if we use the second loss, we cannot update the $encoder$.
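A quick finite-difference check shows why. In this sketch (a hypothetical 1-D codebook, so $d = 1$), nudging $z$ slightly does not change which code is nearest, so the numerical derivative of $z_q$ with respect to $z$ is exactly zero.

```python
def quantize(z, codebook):
    """Nearest codebook entry for a scalar z (d = 1 for simplicity)."""
    return min(codebook, key=lambda e: (z - e) ** 2)


codebook = [-1.0, 0.0, 1.0]  # hypothetical 1-D codebook
z, eps = 0.3, 1e-4
# Central finite difference of z_q with respect to z.
grad = (quantize(z + eps, codebook) - quantize(z - eps, codebook)) / (2 * eps)
print(quantize(z, codebook), grad)  # 0.0 0.0
```

The output only jumps at the boundaries between codes (here at $\pm 0.5$), where the derivative is undefined, so gradient descent through $z_q$ gets no signal for the encoder.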
In other words, our real objective is to minimize $\Vert x - decoder(z_q)\Vert_2^2$, but it is hard to optimize, while $\Vert x - decoder(z)\Vert_2^2$ is easy to optimize but is not our objective. What can we do? A very crude option is to use both:

$$ \Vert x - decoder(z)\Vert_2^2 + \Vert x - decoder(z_q)\Vert_2^2 $$
But this is not ideal, because minimizing $\Vert x - decoder(z)\Vert_2^2$ is not our goal and imposes an additional constraint.
VQ-VAE uses a very clever and direct method called the Straight-Through Estimator ("straight-through estimation"). It originates from Bengio et al.'s paper "Estimating or Propagating Gradients Through Stochastic Neurons for Conditional Computation". The VQ-VAE paper simply cites this work without much explanation, and reading that original paper is not a very approachable route either; it is easier to read the source code.
The idea of Straight-Through is in fact very simple: in the forward pass, use whatever variable you want (even if it is non-differentiable), and in the backward pass, use a gradient you design for it. Following this idea, we design the objective function:

$$ \Vert x - decoder(z + sg[z_q - z])\Vert_2^2 $$

where $sg$ means stop gradient, i.e. no gradient is computed through it. In the forward pass (computing the loss), this is exactly $decoder(z + z_q - z)=decoder(z_q)$. In the backward pass (computing gradients), $z_q - z$ contributes no gradient, so the input behaves like $z$: the decoder's gradient, evaluated at $z_q$, is copied straight to $z$, which lets us optimize the $encoder$.
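To see what $sg$ does without a deep-learning framework, the sketch below uses tiny forward-mode dual numbers. `Dual`, `sg` and the scalar decoder $w\cdot u$ are all hypothetical illustrations, not any library's API. Feeding $z_q$ directly gives zero derivative with respect to $z$; feeding $z + sg[z_q - z]$ gives the same loss value but a nonzero derivative.

```python
class Dual:
    """Forward-mode autodiff number: value `val` and derivative `dot` w.r.t. z."""

    def __init__(self, val, dot=0.0):
        self.val, self.dot = val, dot

    def __add__(self, other):
        return Dual(self.val + other.val, self.dot + other.dot)

    def __sub__(self, other):
        return Dual(self.val - other.val, self.dot - other.dot)

    def __mul__(self, other):
        return Dual(self.val * other.val, self.val * other.dot + self.dot * other.val)


def sg(d):
    """Stop gradient: keep the value, drop the derivative."""
    return Dual(d.val, 0.0)


def quantize(z, codebook):
    """Nearest codebook entry; argmin contributes no derivative."""
    return Dual(min(codebook, key=lambda e: (z.val - e) ** 2), 0.0)


codebook = [-1.0, 0.0, 1.0]   # hypothetical 1-D codebook
x, w = Dual(2.0), Dual(1.5)   # target and weight of a toy decoder(u) = w * u
z = Dual(0.75, 1.0)           # encoder output, seeded with dz/dz = 1
z_q = quantize(z, codebook)   # value 1.0, derivative 0.0


def recon_loss(u):
    r = x - w * u
    return r * r              # ||x - decoder(u)||^2


plain = recon_loss(z_q)              # z_q used directly: no gradient reaches z
st = recon_loss(z + sg(z_q - z))     # straight-through
print(plain.val, plain.dot)          # 0.25 0.0
print(st.val, st.dot)                # 0.25 -1.5
```

The straight-through derivative $-1.5$ equals $-2w(x - w z_q)$, i.e. the decoder-side gradient evaluated at $z_q$ and passed unchanged to $z$.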
Incidentally, this idea lets us customize the gradient of many functions. For example, $x + sg[\text{relu}(x) - x]$ defines the gradient of $\text{relu}(x)$ to be always 1, while its forward value is still exactly $\text{relu}(x)$. In the same way we can assign an arbitrary gradient to any function; whether that has practical value has to be analyzed task by task.
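The ReLU example can be checked with the same dual-number trick (again a hypothetical `Dual` class for illustration only). For a negative input, ordinary ReLU has derivative 0, while $x + sg[\text{relu}(x) - x]$ keeps the ReLU value but reports derivative 1.

```python
class Dual:
    """Forward-mode autodiff number: value `val` and derivative `dot`."""

    def __init__(self, val, dot=0.0):
        self.val, self.dot = val, dot

    def __add__(self, other):
        return Dual(self.val + other.val, self.dot + other.dot)

    def __sub__(self, other):
        return Dual(self.val - other.val, self.dot - other.dot)


def sg(d):
    """Stop gradient: keep the value, drop the derivative."""
    return Dual(d.val, 0.0)


def relu(d):
    """Ordinary ReLU: derivative is 0 for negative inputs."""
    return Dual(max(d.val, 0.0), d.dot if d.val > 0 else 0.0)


def relu_straight_through(x):
    """x + sg[relu(x) - x]: forward value relu(x), derivative always 1."""
    return x + sg(relu(x) - x)


for v in (-2.0, 3.0):
    x = Dual(v, 1.0)  # seed dx/dx = 1
    a, b = relu(x), relu_straight_through(x)
    print(v, a.val, a.dot, b.val, b.dot)
# -2.0 0.0 0.0 0.0 1.0
# 3.0 3.0 1.0 3.0 1.0
```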
Maintaining the Codebook
Note that, given VQ-VAE's nearest-neighbor design, we would expect $z_q$ and $z$ to be very close (indeed, the vectors in the codebook $E$ play the role of cluster centers for the various $z$). This is not guaranteed, however: even if $\Vert x - decoder(z)\Vert_2^2$ and $\Vert x - decoder(z_q)\Vert_2^2$ are both small, it does not follow that $z_q$ and $z$ are close (i.e. $f(z_1)=f(z_2)$ does not imply $z_1 = z_2$).
Therefore, to bring $z_q$ and $z$ closer, we can add $\Vert z - z_q\Vert_2^2$ directly to the loss:

$$ \Vert x - decoder(z + sg[z_q - z])\Vert_2^2 + \beta \Vert z - z_q\Vert_2^2 $$
We can be more careful still. The codebook ($z_q$) is relatively unconstrained, whereas $z$ must work to preserve reconstruction quality, so we should mainly "make $z_q$ approach $z$" rather than "make $z$ approach $z_q$". Since the gradient of $\Vert z_q - z\Vert_2^2$ is the sum of its gradient with respect to $z_q$ and its gradient with respect to $z$, we can equivalently decompose it as

$$ \Vert sg[z] - z_q\Vert_2^2 + \Vert z - sg[z_q]\Vert_2^2 $$
The first term fixes $z$ and pulls $z_q$ toward $z$; the second term, conversely, fixes $z_q$ and pulls $z$ toward $z_q$. Note that this "equivalence" holds for the backward pass (gradients); in the forward pass (loss value), the sum is twice the original. Since, as discussed above, we want "$z_q$ approaches $z$" more than "$z$ approaches $z_q$", we can reweight the final loss:

$$ \Vert x - decoder(z + sg[z_q - z])\Vert_2^2 + \beta \Vert sg[z] - z_q\Vert_2^2 + \gamma \Vert z - sg[z_q]\Vert_2^2 $$

where $\gamma < \beta$. The original paper uses $\gamma = 0.25 \beta$.
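Because each stop-gradient makes a term touch only one variable, the gradients of the two extra terms can be written out by hand. The sketch below (pure Python; the function name is hypothetical and the reconstruction term is omitted) uses $\beta = 1$ and $\gamma = 0.25$, i.e. $\gamma = 0.25\beta$, for a single code vector.

```python
def codebook_and_commitment(z, z_q, beta=1.0, gamma=0.25):
    """Value and gradients of beta*||sg[z] - z_q||^2 + gamma*||z - sg[z_q]||^2.

    The beta term only moves z_q and the gamma term only moves z.
    """
    diff = [zi - qi for zi, qi in zip(z, z_q)]   # z - z_q
    sq = sum(d * d for d in diff)                 # ||z - z_q||^2
    loss = beta * sq + gamma * sq                 # forward value
    grad_z = [2 * gamma * d for d in diff]        # 2*gamma*(z - z_q)
    grad_zq = [-2 * beta * d for d in diff]       # 2*beta*(z_q - z)
    return loss, grad_z, grad_zq


loss, grad_z, grad_zq = codebook_and_commitment([1.0, 2.0], [0.5, 2.5])
print(loss, grad_z, grad_zq)  # 0.625 [0.25, -0.25] [-1.0, 1.0]
```

Setting `beta=1.0, gamma=1.0` reproduces the remark above: the forward value becomes twice $\Vert z - z_q\Vert_2^2$, while each variable still receives the gradient of the undecomposed term.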
(Note: The codebook can also be updated using a moving average approach; please refer to the original paper for details.)
Fitting the Code Distribution
After all these designs, we have finally encoded images into $m\times m$ integer matrices. Because this $m\times m$ matrix still preserves, to some extent, the positional information of the input image, we can fit it with an autoregressive model such as PixelCNN (i.e. model the prior distribution). Once PixelCNN gives us the code distribution, we can randomly sample a new code matrix, map it through the codebook $E$ to a real-valued tensor $z_q$ of shape $m\times m\times d$ (rows × columns × code dimension), and finally pass it through the $decoder$ to obtain an image.
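The generation path can be sketched as follows, assuming an already trained codebook. `sample_codes` is a hypothetical placeholder that draws indices uniformly at random; a real system would sample them autoregressively from the trained PixelCNN prior, and the trained decoder is not shown.

```python
import random


def sample_codes(m, K, seed=0):
    """Placeholder prior: uniform random indices. A real system would sample
    these autoregressively from a trained PixelCNN (not shown)."""
    rng = random.Random(seed)
    return [[rng.randrange(K) for _ in range(m)] for _ in range(m)]


def codes_to_zq(indices, codebook):
    """Look up each index in the codebook: m x m ints -> m x m x d reals."""
    return [[codebook[k] for k in row] for row in indices]


E = [[0.0, 0.0], [1.0, 1.0], [-1.0, 2.0]]  # toy codebook: K = 3, d = 2
idx = sample_codes(m=4, K=len(E))
z_q = codes_to_zq(idx, E)
# x_hat = decoder(z_q)  # the trained decoder is not shown here
print(len(z_q), len(z_q[0]), len(z_q[0][0]))  # 4 4 2
```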
In general, $m\times m$ is much smaller than the original $n\times n\times 3$. For example, in experiments on CelebA, a $128\times 128\times 3$ image could be encoded into a $32\times 32$ code with minimal distortion. Modeling the code matrix with an autoregressive model is therefore much easier than modeling the original image directly.
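Counting only sequence length (and ignoring that the vocabulary changes from 256 to $K$), the CelebA example shortens the sequence the autoregressive prior must generate by a factor of

$$ \frac{128\times 128\times 3}{32\times 32} = \frac{49152}{1024} = 48 $$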
My Implementation
This is my VQ-VAE implementation using Keras (Python 2.7 + Tensorflow 1.8 + Keras 2.2.4, with the model part referencing this):
The main part of this script only covers VQ-VAE encoding and reconstruction (the image at the beginning of the article was reconstructed with this script, and the results are acceptable); it does not include modeling the prior with PixelCNN. However, the comments at the end contain an example that models the prior with attention. The random sampling results are as follows:
The results show, to some extent, that random sampling this way is feasible, but the generation quality is not great. I used PixelAtt instead of PixelCNN because, in my implementation, PixelCNN performed much worse than PixelAtt. The downside is that PixelAtt uses a lot of GPU memory and easily runs out of memory (OOM). Of course, the shortcomings of my own implementation do not mean the method is poor; they may be due to my tuning, an insufficiently deep network, and so on. Personally, I am quite optimistic about research on discrete encodings.
Summary
With that, I have finally explained VQ-VAE in a way I am happy with. Looking back over the whole article, there is no trace of a VAE in it, which is why I say it is really an AE: an AE whose codes are discrete vectors. It can reconstruct fairly sharp images because it keeps sufficiently large feature maps.
If one understands VQ-VAE, then its new 2.0 version is not difficult to comprehend. VQ-VAE-2 has almost no fundamental technical updates compared to VQ-VAE; it simply performs encoding and decoding in two layers (one for global features, one for local features), thereby reducing the blurriness of generated images (at least significantly less, though if you look closely at large VQ-VAE-2 images, there’s still a slight blur).
However, it is worth affirming that the entire VQ-VAE model is quite interesting. Its novel features, such as discrete encoding and assigning gradients using the Straight-Through method, are well worth studying carefully, as they can deepen our understanding of deep learning models and optimization (if you can design gradients, why worry about not designing a good model?).
@online{kexuefm-6760,
title={A Concise Introduction to VQ-VAE: Vector Quantized AutoEncoder},
author={苏剑林},
year={2019},
month={06},
url={\url{https://kexue.fm/archives/6760}},
}