This is a gemini-2.5-flash translation of a Chinese article. It has NOT been vetted for errors. You should have the original article open in a parallel tab at all times.
In this article, we explore a concept called “gradient flow”. Simply put, gradient flow joins the iterates that gradient descent produces while searching for a minimum into a trajectory that evolves with (virtual) time, and this trajectory is called the “gradient flow”. In the second half of the article, we extend gradient flow to the space of probability distributions, obtaining the “Wasserstein gradient flow”, which offers a new perspective on the continuity equation, the Fokker-Planck equation, and related concepts.
Gradient Descent
Suppose we want to find the minimum of a smooth function $f(\boldsymbol{x})$. A common approach is gradient descent, which iterates as follows:
$$ \boldsymbol{x}_{t+1} = \boldsymbol{x}_t -\alpha \nabla_{\boldsymbol{x}_t}f(\boldsymbol{x}_t) \tag{1} $$
If $f(\boldsymbol{x})$ is convex in $\boldsymbol{x}$, gradient descent can usually find the minimum point; otherwise, it typically only converges to a “stationary point”, i.e., a point where the gradient is zero. In a relatively favorable scenario, it may converge to a local minimum. Here we do not strictly distinguish between local and global minima, because in deep learning even converging to a local minimum is quite rare.
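If we write $\alpha$ as $\Delta t$ and $\boldsymbol{x}_{t+1}$ as $\boldsymbol{x}_{t+\Delta t}$, equation (1) can be rearranged as $\frac{\boldsymbol{x}_{t+\Delta t} - \boldsymbol{x}_t}{\Delta t} = -\nabla_{\boldsymbol{x}_t}f(\boldsymbol{x}_t)$. Taking the limit $\Delta t\to 0$, it becomes an ODE: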
$$ \frac{d\boldsymbol{x}_t}{dt} = -\nabla_{\boldsymbol{x}_t}f(\boldsymbol{x}_t) \tag{2} $$
The trajectory $\boldsymbol{x}_t$ obtained by solving this ODE is what we call the “gradient flow”. In other words, gradient flow is the trajectory that gradient descent traces out while searching for a minimum. Along a solution of equation (2), we also have:
$$ \frac{df(\boldsymbol{x}_t)}{dt} = \left\langle\nabla_{\boldsymbol{x}_t}f(\boldsymbol{x}_t),\frac{d\boldsymbol{x}_t}{dt}\right\rangle = -\Vert\nabla_{\boldsymbol{x}_t}f(\boldsymbol{x}_t)\Vert^2 \leq 0 $$
This means that as long as $\nabla_{\boldsymbol{x}_t}f(\boldsymbol{x}_t)\neq\boldsymbol{0}$, $f$ strictly decreases along the gradient flow; correspondingly, gradient descent with a sufficiently small learning rate always moves in a direction that decreases $f(\boldsymbol{x})$.
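To make the correspondence between equations (1) and (2) concrete, here is a minimal Python sketch on a hypothetical diagonal quadratic $f(\boldsymbol{x}) = \frac{1}{2}(a_1x_1^2 + a_2x_2^2)$, chosen only because its gradient flow has the closed form $x_i(t) = x_i(0)e^{-a_i t}$. It runs gradient descent with step $\alpha$ for $1/\alpha$ steps, compares the endpoint with the exact flow at $t=1$, and checks that $f$ decreases at every step. The constants and helper names are illustrative assumptions, not from the original article.

```python
import math

# Hypothetical test problem: f(x) = 0.5 * (a1 * x1^2 + a2 * x2^2).
# Its gradient flow dx/dt = -grad f(x) has the closed form x_i(t) = x_i(0) * exp(-a_i * t).
A = (1.0, 4.0)    # hypothetical curvatures a1, a2
X0 = (1.0, -2.0)  # hypothetical starting point


def f(x):
    return 0.5 * sum(a * xi * xi for a, xi in zip(A, x))


def grad_f(x):
    return tuple(a * xi for a, xi in zip(A, x))


def gradient_descent(x0, alpha, n_steps):
    """Iterate Eq. (1), x <- x - alpha * grad f(x), and return every iterate."""
    path = [x0]
    x = x0
    for _ in range(n_steps):
        x = tuple(xi - alpha * gi for xi, gi in zip(x, grad_f(x)))
        path.append(x)
    return path


def gradient_flow(x0, t):
    """Exact solution of the ODE in Eq. (2) for this quadratic."""
    return tuple(xi * math.exp(-a * t) for a, xi in zip(A, x0))


T = 1.0
errors = {}
for alpha in (0.1, 0.01, 0.001):
    # alpha plays the role of the time step dt, so 1/alpha steps reach t = 1
    x_gd = gradient_descent(X0, alpha, round(T / alpha))[-1]
    x_flow = gradient_flow(X0, T)
    errors[alpha] = max(abs(u - v) for u, v in zip(x_gd, x_flow))
    print(f"alpha={alpha}: max |x_GD - x_flow| at t=1 is {errors[alpha]:.1e}")

# With a small learning rate, f decreases at every step, mirroring df/dt <= 0.
values = [f(x) for x in gradient_descent(X0, 0.01, 100)]
monotone = all(later <= earlier for earlier, later in zip(values, values[1:]))
```

Shrinking $\alpha$ by a factor of ten shrinks the gap by roughly the same factor, consistent with gradient descent being the forward-Euler discretization of the gradient flow.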
For more related discussions, refer to previous articles in the optimization algorithms series, such as “Optimization Algorithms from a Dynamical Systems Perspective (I): From SGD to Momentum Acceleration” and “Optimization Algorithms from a Dynamical Systems Perspective (III): A More Holistic View”.
Steepest Direction
Why use gradient descent? A popular explanation is that “the negative gradient direction is the direction of steepest local descent”, and searching for this phrase turns up plenty of material. The statement is not wrong, but it is imprecise because it omits its premise: “steepest” (or “fastest”) implies a quantitative comparison, and which direction is steepest can only be determined once the metric used for the comparison has been fixed.
If we only care about the direction of fastest descent, the gradient descent update should be defined by the following constrained problem:
$$ \boldsymbol{x}_{t+1} = \mathop{\text{argmin}}_{\boldsymbol{x},\Vert\boldsymbol{x} - \boldsymbol{x}_t\Vert = \epsilon} f(\boldsymbol{x}) \tag{3} $$
Assuming a first-order approximation is sufficient, we have:
$$ \begin{aligned} f(\boldsymbol{x})&\,\approx f(\boldsymbol{x}_t) + \langle \nabla_{\boldsymbol{x}_t}f(\boldsymbol{x}_t),\boldsymbol{x} - \boldsymbol{x}_t\rangle\\ &\, \geq f(\boldsymbol{x}_t) - \Vert\nabla_{\boldsymbol{x}_t}f(\boldsymbol{x}_t)\Vert \Vert\boldsymbol{x} - \boldsymbol{x}_t\Vert\\ &\, = f(\boldsymbol{x}_t) - \Vert\nabla_{\boldsymbol{x}_t}f(\boldsymbol{x}_t)\Vert \epsilon \end{aligned} $$
where the inequality follows from the Cauchy-Schwarz inequality. Equality holds when:
$$ \boldsymbol{x} - \boldsymbol{x}_t = -\epsilon\frac{\nabla_{\boldsymbol{x}_t}f(\boldsymbol{x}_t)}{\Vert\nabla_{\boldsymbol{x}_t}f(\boldsymbol{x}_t)\Vert}\quad\Rightarrow\quad\boldsymbol{x}_{t+1} = \boldsymbol{x}_t - \epsilon\frac{\nabla_{\boldsymbol{x}_t}f(\boldsymbol{x}_t)}{\Vert\nabla_{\boldsymbol{x}_t}f(\boldsymbol{x}_t)\Vert} \tag{4} $$
As can be seen, the update direction is precisely the negative direction of the gradient, which is why it is said to be the direction of the steepest local descent. However, don’t forget that this is derived under the constraint $\Vert\boldsymbol{x} - \boldsymbol{x}_t\Vert = \epsilon$, where $\Vert\cdot\Vert$ is the Euclidean norm. If we change the definition of the norm, or simply change the constraint, the result will be different. Therefore, strictly speaking, it should be “in Euclidean space, the negative direction of the gradient is the direction of the steepest local descent.”
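The dependence on the norm can be checked with a small brute-force experiment. The Python sketch below assumes a hypothetical fixed gradient $\boldsymbol{g}=(3,1)$ and step size $\epsilon=0.1$, and searches over directions for the step $\boldsymbol{d}$ of size $\epsilon$ that minimizes the first-order change $\langle \boldsymbol{g}, \boldsymbol{d}\rangle$, once under the Euclidean norm and once under the max norm $\Vert\boldsymbol{d}\Vert_\infty$. Under the Euclidean norm it recovers the step in equation (4); under the max norm the best step is $-\epsilon\,\text{sign}(\boldsymbol{g})$, which is not parallel to the gradient.

```python
import math

# First-order model: f(x_t + d) ~ f(x_t) + <g, d>. We look for the step d of
# "size" eps that makes <g, d> smallest, under two different norms.
g = (3.0, 1.0)  # hypothetical gradient at x_t
eps = 0.1


def best_step(norm, n_angles=3600):
    """Brute-force search over directions on the unit circle, rescaled to norm eps."""
    best_d, best_val = None, float("inf")
    for i in range(n_angles):
        theta = 2 * math.pi * i / n_angles
        u = (math.cos(theta), math.sin(theta))
        d = tuple(eps * ui / norm(u) for ui in u)  # now norm(d) == eps
        val = g[0] * d[0] + g[1] * d[1]
        if val < best_val:
            best_d, best_val = d, val
    return best_d


def l2(v):
    return math.hypot(*v)


def linf(v):
    return max(abs(vi) for vi in v)


d_l2 = best_step(l2)
d_linf = best_step(linf)

print("Euclidean norm:", d_l2, "vs -eps * g / ||g||:",
      tuple(-eps * gi / l2(g) for gi in g))
print("max norm      :", d_linf, "vs -eps * sign(g):",
      tuple(-eps * math.copysign(1.0, gi) for gi in g))
```

The grid search only resolves directions to 0.1 degrees, so the Euclidean result matches the closed form approximately, while the max-norm optimum sits exactly on a grid point (a corner of the square).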
Optimization Perspective
Equation (3) is a constrained optimization problem, which is cumbersome to generalize and to solve. Moreover, its solution, equation (4), normalizes the gradient and is therefore not the original gradient descent (1) either. In fact, one can show that the optimization objective corresponding to equation (1) is:
$$ \boldsymbol{x}_{t+1} = \mathop{\text{argmin}}_{\boldsymbol{x}} \frac{\Vert\boldsymbol{x} - \boldsymbol{x}_t\Vert^2}{2\alpha} + f(\boldsymbol{x}) \tag{5} $$
That is, by moving the constraint into the objective as a penalty term, we no longer have to handle the constraint explicitly, and the formulation is easier to generalize. Moreover, despite the extra term $\frac{\Vert\boldsymbol{x} - \boldsymbol{x}_t\Vert^2}{2\alpha}$, this update is guaranteed never to make things worse: substituting $\boldsymbol{x} = \boldsymbol{x}_t$ gives an objective value of exactly $f(\boldsymbol{x}_t)$, so the minimum of the objective is at most $f(\boldsymbol{x}_t)$, and since the penalty term is non-negative, $f(\boldsymbol{x}_{t+1})\leq f(\boldsymbol{x}_t)$.
When $\alpha$ is sufficiently small, the first term dominates, so $\Vert\boldsymbol{x} - \boldsymbol{x}_t\Vert$ must be small for the objective to be small; that is, the optimal point should lie very close to $\boldsymbol{x}_t$. We can therefore expand $f(\boldsymbol{x})$ to first order around $\boldsymbol{x}_t$ and obtain:
$$ \boldsymbol{x}_{t+1} = \mathop{\text{argmin}}_{\boldsymbol{x}} \frac{\Vert\boldsymbol{x} - \boldsymbol{x}_t\Vert^2}{2\alpha} + f(\boldsymbol{x}_t)+\langle\nabla_{\boldsymbol{x}_t}f(\boldsymbol{x}_t),\boldsymbol{x}-\boldsymbol{x}_t\rangle $$
This is now simply a quadratic minimization problem: setting the gradient with respect to $\boldsymbol{x}$ to zero gives $\frac{\boldsymbol{x} - \boldsymbol{x}_t}{\alpha} + \nabla_{\boldsymbol{x}_t}f(\boldsymbol{x}_t) = \boldsymbol{0}$, whose solution is precisely equation (1).
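Both claims above (the objective in equation (5) never increases $f$, and its first-order expansion reduces to equation (1)) can be checked numerically. The sketch below uses a hypothetical convex one-dimensional $f(x) = x^4/4 - x$, computes the exact minimizer of equation (5) by bisection on the derivative of the objective, and compares it with the gradient descent step for several values of $\alpha$. The test function and helper names are illustrative assumptions.

```python
# Hypothetical convex test function (f'' = 3x^2 >= 0), so the objective in Eq. (5)
# is strictly convex for every alpha > 0 and has a unique minimizer.
def f(x):
    return x**4 / 4 - x


def df(x):
    return x**3 - 1


def prox_step(x_t, alpha, lo=-10.0, hi=10.0, iters=200):
    """Exact minimizer of Eq. (5): (x - x_t)^2 / (2 * alpha) + f(x).

    For a strictly convex objective the minimizer is the unique root of its
    derivative (x - x_t) / alpha + f'(x), which is increasing in x, so
    bisection on [lo, hi] finds it.
    """
    for _ in range(iters):
        mid = 0.5 * (lo + hi)
        if (mid - x_t) / alpha + df(mid) > 0:
            hi = mid  # root lies to the left of mid
        else:
            lo = mid  # root lies to the right of mid
    return 0.5 * (lo + hi)


def gd_step(x_t, alpha):
    """Eq. (1): the minimizer of the linearized objective."""
    return x_t - alpha * df(x_t)


x_t = 2.0
gaps, decreased = {}, []
for alpha in (0.01, 0.005, 0.0025):
    x_prox, x_gd = prox_step(x_t, alpha), gd_step(x_t, alpha)
    gaps[alpha] = abs(x_prox - x_gd)
    decreased.append(f(x_prox) <= f(x_t))  # the exact step never increases f
    print(f"alpha={alpha}: prox={x_prox:.6f}, gd={x_gd:.6f}, gap={gaps[alpha]:.2e}")
```

Each halving of $\alpha$ shrinks the gap by close to a factor of four, i.e., the gap is $O(\alpha^2)$, as expected when gradient descent is the first-order approximation of equation (5). The exact minimizer of equation (5) is commonly called a proximal-point (or implicit Euler) step.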
Clearly, besides the squared norm, we can also consider other regularization terms, thereby forming different gradient descent schemes. For example, Natural Gradient Descent uses KL divergence as the regularization term:
$$ \boldsymbol{x}_{t+1} = \mathop{\text{argmin}}_{\boldsymbol{x}} \frac{KL(p(\boldsymbol{y}|\boldsymbol{x})\Vert p(\boldsymbol{y}|\boldsymbol{x}_t))}{\alpha} + f(\boldsymbol{x}) $$
where $p(\boldsymbol{y}|\boldsymbol{x})$ is a probability distribution associated with $f(\boldsymbol{x})$. To solve this, we again expand $f(\boldsymbol{x})$ to first order. The KL divergence, however, is special: its first-order expansion vanishes (refer to here), so it must be expanded to at least second order. The overall result is:
$$ \boldsymbol{x}_{t+1} = \mathop{\text{argmin}}_{\boldsymbol{x}} \frac{(\boldsymbol{x}-\boldsymbol{x}_t)^{\top}\boldsymbol{F}(\boldsymbol{x}-\boldsymbol{x}_t)}{2\alpha} + f(\boldsymbol{x}_t)+\langle\nabla_{\boldsymbol{x}_t}f(\boldsymbol{x}_t),\boldsymbol{x}-\boldsymbol{x}_t\rangle $$
Here, $\boldsymbol{F}$ is the Fisher information matrix. We will not go through the calculation here; the derivation can also be found here. The above is again a quadratic minimization problem, and its solution is:
$$ \boldsymbol{x}_{t+1} = \boldsymbol{x}_t -\alpha \boldsymbol{F}^{-1}\nabla_{\boldsymbol{x}_t}f(\boldsymbol{x}_t) $$
This is what is known as “natural gradient descent”.
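As a concrete check of the second-order expansion, consider a hypothetical family $p(y|\boldsymbol{x}) = \mathcal{N}(y;\mu,\sigma^2)$ with parameters $\boldsymbol{x}=(\mu,\sigma)$, for which the Fisher information matrix is known to be $\boldsymbol{F}=\text{diag}(1/\sigma^2, 2/\sigma^2)$. The Python sketch compares the exact KL divergence for a small parameter change with $\frac{1}{2}(\boldsymbol{x}-\boldsymbol{x}_t)^{\top}\boldsymbol{F}(\boldsymbol{x}-\boldsymbol{x}_t)$, then takes one natural gradient step using a made-up gradient vector (the article does not fix a particular $f$).

```python
import math

# Hypothetical family: p(y | x) = N(y; mu, sigma^2) with parameters x = (mu, sigma).
# Its Fisher information matrix is diagonal: F = diag(1 / sigma^2, 2 / sigma^2).


def kl_gauss(mu1, s1, mu2, s2):
    """Closed-form KL(N(mu1, s1^2) || N(mu2, s2^2))."""
    return math.log(s2 / s1) + (s1**2 + (mu1 - mu2) ** 2) / (2 * s2**2) - 0.5


def fisher_diag(mu, sigma):
    """Diagonal of the Fisher information matrix at (mu, sigma)."""
    return (1 / sigma**2, 2 / sigma**2)


mu, sigma = 0.5, 2.0
delta = (1e-3, -2e-3)  # a small parameter change x - x_t
F = fisher_diag(mu, sigma)
kl = kl_gauss(mu + delta[0], sigma + delta[1], mu, sigma)  # KL(p(y|x) || p(y|x_t))
quad = 0.5 * (F[0] * delta[0] ** 2 + F[1] * delta[1] ** 2)  # 0.5 * delta^T F delta
print(f"KL = {kl:.6e}, second-order approximation = {quad:.6e}")


def natural_gradient_step(params, grad, alpha):
    """x_{t+1} = x_t - alpha * F^{-1} grad; F is diagonal, so F^{-1} is elementwise."""
    fisher = fisher_diag(*params)
    return tuple(p - alpha * g / fi for p, g, fi in zip(params, grad, fisher))


grad = (0.2, -0.4)  # made-up gradient of f at (mu, sigma)
new_params = natural_gradient_step((mu, sigma), grad, alpha=0.1)
print(new_params)
```

Compared with plain gradient descent, the natural gradient rescales the two components by $\sigma^2$ and $\sigma^2/2$, so the effective step adapts to how strongly each parameter changes the distribution.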
Introduction to Functionals
The term “functional” might sound intimidating, but in fact, it should be familiar to regular readers of this site. Simply put, an ordinary multivariate function takes a vector as input and outputs a scalar, whereas a functional takes a function as input and outputs a scalar, such as the definite integral operation:
$$ \mathcal{I}[f] = \int_a^b f(x)dx $$
For any function $f$, the result of calculating $\mathcal{I}[f]$ is a scalar, so $\mathcal{I}[f]$ is a functional. Another example is the KL divergence mentioned earlier, which is defined as:
$$ KL(p\Vert q) = \int p(\boldsymbol{x})\log \frac{p(\boldsymbol{x})}{q(\boldsymbol{x})}d\boldsymbol{x} $$
Here the integral is implicitly over the entire space. If $p(\boldsymbol{x})$ is fixed, this is a functional of $q(\boldsymbol{x})$: $q(\boldsymbol{x})$ is a function, and feeding in any function that satisfies the required conditions makes $KL(p\Vert q)$ output a scalar. More generally, the $f$-divergence introduced in “Introduction to f-GAN: The GAN Model Workshop” is also a functional. These are relatively simple functionals; more complex ones may involve derivatives of the input function, such as the action functional in the principle of least action in theoretical physics.
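To make “a function goes in, a scalar comes out” concrete, the Python sketch below approximates two of the functionals above on a grid: the definite integral $\mathcal{I}[f]$ and $KL(p\Vert q)$ with $p$ fixed. The Gaussian choices for $p$ and $q$, the midpoint rule, and the truncation of the integration range are illustrative assumptions; the KL result is compared with the standard closed form for one-dimensional Gaussians.

```python
import math


def integrate(fn, a, b, n=20000):
    """Midpoint-rule approximation of the definite integral of fn over [a, b]."""
    h = (b - a) / n
    return h * sum(fn(a + (i + 0.5) * h) for i in range(n))


def gauss_pdf(mu, sigma):
    """Return the density of N(mu, sigma^2) as a Python function."""
    norm = sigma * math.sqrt(2 * math.pi)
    return lambda x: math.exp(-0.5 * ((x - mu) / sigma) ** 2) / norm


def kl_functional(p, q, a=-20.0, b=20.0):
    """KL(p || q) for a fixed p: takes the function q, returns a scalar.

    The integral over the whole line is truncated to [a, b]; this is harmless
    for the Gaussians below, whose mass outside that range is negligible.
    """
    return integrate(lambda x: p(x) * math.log(p(x) / q(x)), a, b)


p = gauss_pdf(0.0, 1.0)
q = gauss_pdf(1.0, 2.0)
numeric = kl_functional(p, q)
# Standard closed form for 1-D Gaussians:
# log(s_q / s_p) + (s_p^2 + (m_p - m_q)^2) / (2 s_q^2) - 1/2
exact = math.log(2.0 / 1.0) + (1.0**2 + (0.0 - 1.0) ** 2) / (2 * 2.0**2) - 0.5
print(f"KL numeric = {numeric:.8f}, closed form = {exact:.8f}")

sin_integral = integrate(math.sin, 0.0, math.pi)  # I[sin] on [0, pi] equals 2
print(sin_integral)
```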
Below, we will primarily focus on functionals whose domain is the set of all probability density functions, i.e., functionals that take a probability density as input and output a scalar.
Flow of Probability
Suppose we have a functional $\mathcal{F}[q]$ and want to compute its minimum. Following the idea of gradient descent, as long as we can find some form of its gradient, we can iterate along its negative direction.
To determine the iteration format, we follow our earlier line of thought and generalize equation (5). Here $f(\boldsymbol{x})$ is naturally replaced by $\mathcal{F}[q]$. What should replace the first (regularization) term? In equation (5) it is the squared Euclidean distance, so it is natural to replace it with some squared distance between distributions. For probability distributions, a well-behaved choice is the Wasserstein distance (specifically, the “2-Wasserstein distance”):
$$ \mathcal{W}_2[p,q]=\sqrt{\inf_{\gamma\in \Pi[p,q]} \iint \gamma(\boldsymbol{x},\boldsymbol{y}) \Vert\boldsymbol{x}-\boldsymbol{y}\Vert^2 d\boldsymbol{x}d\boldsymbol{y}} $$
where $\Pi[p,q]$ is the set of all joint distributions (couplings) whose marginals are $p$ and $q$. We will not go into detail here; interested readers can refer to “From Wasserstein Distance and Duality Theory to WGAN”. Replacing the squared Euclidean distance in equation (5) with the squared Wasserstein distance, the objective becomes:
$$ q_{t+1} = \mathop{\text{argmin}}_{q} \frac{\mathcal{W}_2^2[q,q_t]}{2\alpha} + \mathcal{F}[q] $$
Unfortunately, I cannot give a concise derivation of the solution to this objective, and I do not fully understand the derivation myself. Following the literature, such as “Introduction to Gradient Flows in the 2-Wasserstein Space” and “{ Euclidean, Metric, and Wasserstein } Gradient Flows: an overview”, I will simply state the result (to first order in $\alpha$):
$$ q_{t+1}(\boldsymbol{x}) = q_t(\boldsymbol{x}) + \alpha \nabla_{\boldsymbol{x}}\cdot\left(q_t(\boldsymbol{x})\nabla_{\boldsymbol{x}}\frac{\delta \mathcal{F}[q_t(\boldsymbol{x})]}{\delta q_t(\boldsymbol{x})}\right) $$
Taking the limit $\alpha\to 0$, as in the passage from equation (1) to equation (2), we get:
$$ \frac{\partial q_t(\boldsymbol{x})}{\partial t} = \nabla_{\boldsymbol{x}}\cdot\left(q_t(\boldsymbol{x})\nabla_{\boldsymbol{x}}\frac{\delta \mathcal{F}[q_t(\boldsymbol{x})]}{\delta q_t(\boldsymbol{x})}\right) $$
This is the “Wasserstein gradient flow”, where $\frac{\delta \mathcal{F}[q]}{\delta q}$ is the variational derivative of $\mathcal{F}[q]$. For a functional defined as the integral of a function of $q$ alone (with no derivatives of $q$), the variational derivative is simply the derivative of the integrand with respect to $q$:
$$ \mathcal{F}[q] = \int F(q(\boldsymbol{x}))d\boldsymbol{x} \quad\Rightarrow\quad \frac{\delta \mathcal{F}[q(\boldsymbol{x})]}{\delta q(\boldsymbol{x})} = \frac{\partial F(q(\boldsymbol{x}))}{\partial q(\boldsymbol{x})} $$
Some Examples
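A simple first example is to check the variational-derivative formula above numerically. For the hypothetical functional $\mathcal{F}[q]=\int q\log q\,d\boldsymbol{x}$ (so $F(q) = q\log q$), the formula predicts $\frac{\delta\mathcal{F}}{\delta q} = \log q + 1$. By definition of the variational derivative, $\frac{d}{d\epsilon}\mathcal{F}[q+\epsilon\eta]\big|_{\epsilon=0} = \int \frac{\delta\mathcal{F}}{\delta q}\eta\,d\boldsymbol{x}$ for any perturbation $\eta$; the Python sketch compares both sides in one dimension, with a Gaussian $q$, a bump $\eta$, and a truncated grid, all arbitrary choices made for illustration.

```python
import math

# Hypothetical functional F[q] = integral of q log q dx, so the integrand is
# F(q) = q log q and the formula above predicts the variational derivative log q + 1.
N, A, B = 40000, -10.0, 10.0  # grid size and (truncated) integration range
H = (B - A) / N
GRID = [A + (i + 0.5) * H for i in range(N)]


def functional(q_vals):
    """Midpoint-rule approximation of the integral of q log q, for q sampled on GRID."""
    return H * math.fsum(v * math.log(v) for v in q_vals)


q = [math.exp(-0.5 * x * x) / math.sqrt(2 * math.pi) for x in GRID]  # standard normal
eta = [math.exp(-((x - 1.0) ** 2)) for x in GRID]  # arbitrary smooth perturbation

# Left side: d/d(eps) F[q + eps * eta] at eps = 0, by a central finite difference.
eps = 1e-5
plus = functional([qi + eps * ei for qi, ei in zip(q, eta)])
minus = functional([qi - eps * ei for qi, ei in zip(q, eta)])
directional = (plus - minus) / (2 * eps)

# Right side: integral of (dF/dq) * eta with the predicted dF/dq = log q + 1.
predicted = H * math.fsum((math.log(qi) + 1.0) * ei for qi, ei in zip(q, eta))
print(f"finite difference = {directional:.8f}, formula = {predicted:.8f}")
```

`math.fsum` is used so that the small difference `plus - minus` is not swamped by floating-point rounding in the sums.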
According to “Introduction to f-GAN: The GAN Model Workshop”, the $f$-divergence is defined as:
$$ \mathcal{D}_f(p\Vert q) = \int q(\boldsymbol{x}) f\left(\frac{p(\boldsymbol{x})}{q(\boldsymbol{x})}\right)d\boldsymbol{x} $$
Fixing $p$ and letting $\mathcal{F}[q]=\mathcal{D}_f(p\Vert q)$, the derivative of the integrand $q f(p/q)$ with respect to $q$ is $f(p/q) - \frac{p}{q}f'(p/q)$, so we obtain:
$$ \frac{\partial q_t(\boldsymbol{x})}{\partial t} = \nabla_{\boldsymbol{x}}\cdot\Big(q_t(\boldsymbol{x})\nabla_{\boldsymbol{x}}\big(f(r_t(\boldsymbol{x})) - r_t(\boldsymbol{x}) f'(r_t(\boldsymbol{x}))\big)\Big) \tag{6} $$
where $r_t(\boldsymbol{x}) = \frac{p(\boldsymbol{x})}{q_t(\boldsymbol{x})}$. According to “Deriving the Continuity Equation and Fokker-Planck Equation Using the Test Function Method”, this equation has the form of a continuity equation, so we can sample from $q_t$ by moving samples according to the ODE:
$$ \frac{d\boldsymbol{x}}{dt} = -\nabla_{\boldsymbol{x}}\big(f(r_t(\boldsymbol{x})) - r_t(\boldsymbol{x}) f'(r_t(\boldsymbol{x}))\big) $$
Moreover, by the discussion above, equation (6) is the Wasserstein gradient flow that minimizes the $f$-divergence between $p$ and $q$. As $t\to\infty$, the $f$-divergence goes to zero, i.e., $q_t\to p$, so in this limit the ODE above yields samples from $p$. However, this result is currently only of formal significance and has no practical utility: it requires knowing the expression for $p$ and also solving equation (6) for $q_t$ in order to evaluate the right-hand side of the ODE. This is computationally very hard and usually infeasible.
A relatively simple example is the reverse KL divergence $KL(q\Vert p)$, obtained with $f=-\log$. In this case $f(r) - rf'(r) = 1 - \log r$, and substituting into equation (6) yields:
$$ \begin{aligned} \frac{\partial q_t(\boldsymbol{x})}{\partial t} =& - \nabla_{\boldsymbol{x}}\cdot\left(q_t(\boldsymbol{x})\nabla_{\boldsymbol{x}}\log \frac{p(\boldsymbol{x})}{q_t(\boldsymbol{x})}\right)\\ =& - \nabla_{\boldsymbol{x}}\cdot\Big(q_t(\boldsymbol{x})\nabla_{\boldsymbol{x}}\big(\log p(\boldsymbol{x}) - \log q_t(\boldsymbol{x})\big)\Big)\\ =& - \nabla_{\boldsymbol{x}}\cdot\big(q_t(\boldsymbol{x})\nabla_{\boldsymbol{x}}\log p(\boldsymbol{x})\big) + \nabla_{\boldsymbol{x}}\cdot\nabla_{\boldsymbol{x}} q_t(\boldsymbol{x}) \end{aligned} $$
Comparing again with the results of “Deriving the Continuity Equation and Fokker-Planck Equation Using the Test Function Method”, this is precisely a Fokker-Planck equation, corresponding to the SDE:
$$ d\boldsymbol{x} = \nabla_{\boldsymbol{x}}\log p(\boldsymbol{x}) dt + \sqrt{2}d\boldsymbol{w} $$
That is, if we know $\nabla_{\boldsymbol{x}}\log p(\boldsymbol{x})$ (which only requires $\log p(\boldsymbol{x})$ up to an additive constant), we can use this SDE to sample from $p(\boldsymbol{x})$; this is the (overdamped) Langevin dynamics. Compared with the previous ODE, this scheme avoids having to solve for $q_t(\boldsymbol{x})$ and is therefore a relatively practical approach.
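A minimal sketch of this sampler, assuming a one-dimensional Gaussian target $p=\mathcal{N}(1.5, 0.5^2)$ chosen only because its score $\nabla_x\log p(x) = -(x-\mu)/\sigma^2$ is available in closed form. The SDE is discretized with the Euler-Maruyama scheme (the step size $h$ is an assumption), giving the update $x \leftarrow x + h\nabla_x\log p(x) + \sqrt{2h}\,\xi$ with $\xi\sim\mathcal{N}(0,1)$, and many particles are evolved in parallel from a deliberately wrong starting distribution.

```python
import math
import random

# Hypothetical target: p = N(MU, SIGMA^2), whose score is
# grad log p(x) = -(x - MU) / SIGMA^2 (the normalizing constant is never needed).
MU, SIGMA = 1.5, 0.5


def score(x):
    return -(x - MU) / SIGMA**2


def langevin_sample(n_particles=2000, n_steps=500, h=0.01, seed=0):
    """Euler-Maruyama discretization of dx = grad log p(x) dt + sqrt(2) dw.

    Particles start from a broad, wrong distribution N(0, 3^2); after enough
    steps their empirical distribution approximates p, up to an O(h)
    discretization bias (there is no Metropolis-Hastings correction).
    """
    rng = random.Random(seed)
    xs = [rng.gauss(0.0, 3.0) for _ in range(n_particles)]
    noise_scale = math.sqrt(2 * h)
    for _ in range(n_steps):
        xs = [x + h * score(x) + noise_scale * rng.gauss(0.0, 1.0) for x in xs]
    return xs


samples = langevin_sample()
mean = sum(samples) / len(samples)
var = sum((x - mean) ** 2 for x in samples) / (len(samples) - 1)
print(f"mean ~ {mean:.3f} (target {MU}), var ~ {var:.3f} (target {SIGMA**2})")
```

For this Gaussian target the stationary variance of the discretized chain is slightly larger than $\sigma^2$; shrinking $h$ (at the cost of more steps) reduces that bias.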
Summary
This article introduced the concept of “gradient flow”, the trajectory that gradient descent traces out while searching for a minimum, together with its extension from vector spaces to probability spaces (the Wasserstein gradient flow) and its connections to the continuity equation, the Fokker-Planck equation, and ODE/SDE-based sampling.
@online{kexuefm-9660,
title={Gradient Flow: Exploring the Path to Minimums},
author={苏剑林},
year={2023},
month={06},
url={\url{https://kexue.fm/archives/9660}},
}