This is a gemini-2.5-flash translation of a Chinese article.
It has NOT been vetted for errors. You should have the original article open in a parallel tab at all times.
As mentioned last time, when we shifted the optimization object from vector parameters to matrix parameters and adopted the spectral norm constraint, which is more suitable for matrices, the Muon optimizer naturally emerged. Furthermore, we considered the steepest descent direction after adding an orthogonal constraint to the parameters. This was discussed in two parts: square matrices and non-square matrices. The solution for square matrices was completed in our previous article, but the non-square part remained unresolved.
The goal of this article is to complete the solution for the non-square part, thereby fully addressing optimization under orthogonal constraints.
Task Information#
Let’s briefly review the results from the previous article, Steepest Descent on Manifolds: 2. Muon + Orthogonal. Our objective is to solve:
$$ \max_{\boldsymbol{\Phi}} \tr(\boldsymbol{G}^{\top}\boldsymbol{\Phi}) \qquad \text{s.t.}\qquad \Vert\boldsymbol{\Phi}\Vert_2 = 1,\,\, \boldsymbol{W}^{\top}\boldsymbol{W}=\boldsymbol{I},\,\,(\boldsymbol{W} - \eta \boldsymbol{\Phi})^{\top}(\boldsymbol{W} - \eta \boldsymbol{\Phi})=\boldsymbol{I} \tag{1} $$where $\boldsymbol{W},\boldsymbol{\Phi}\in\mathbb{R}^{n\times m}(n \geq m)$, and $\Vert\cdot\Vert_2$ is the spectral norm. Based on the principle that “first-order approximation is sufficient,” this can be simplified to:
$$ \max_{\boldsymbol{\Phi}} \tr(\boldsymbol{G}^{\top}\boldsymbol{\Phi}) \qquad \text{s.t.}\qquad \Vert\boldsymbol{\Phi}\Vert_2 = 1,\,\, \boldsymbol{W}^{\top}\boldsymbol{W}=\boldsymbol{I},\,\,\boldsymbol{W}^{\top}\boldsymbol{\Phi}+\boldsymbol{\Phi}^{\top}\boldsymbol{W} = \boldsymbol{0} $$Here, the set of all $\boldsymbol{\Phi}$ satisfying $\boldsymbol{W}^{\top}\boldsymbol{\Phi}+\boldsymbol{\Phi}^{\top}\boldsymbol{W} = \boldsymbol{0}$ is also called the “tangent space” of $\boldsymbol{W}^{\top}\boldsymbol{W}=\boldsymbol{I}$. In the previous article, we already derived the general solution form:
$$ \boldsymbol{\Phi} = \msign(\boldsymbol{G} + \boldsymbol{W}\boldsymbol{X}) $$where $\boldsymbol{X}\in\mathbb{R}^{m\times m}$ is an undetermined symmetric matrix.
The remaining challenge is to compute the symmetric matrix $\boldsymbol{X}$ such that $\boldsymbol{W}^{\top}\boldsymbol{\Phi}$ is anti-symmetric. Once such an $\boldsymbol{X}$ is found, the corresponding $\boldsymbol{\Phi}$ is automatically the optimal solution. For $n=m$, we have already found the closed-form solution $\boldsymbol{X}=-[\boldsymbol{W}^{\top}\boldsymbol{G}]_{\text{sym}}$; the truly difficult case is $n > m$, where the constraint set is known as the “Stiefel manifold,” and this is precisely the open problem left in Orthogonal manifold.
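As a quick sanity check of the square-case closed form recalled above, the following NumPy sketch verifies numerically that $\boldsymbol{X}=-[\boldsymbol{W}^{\top}\boldsymbol{G}]_{\text{sym}}$ makes $\boldsymbol{W}^{\top}\boldsymbol{\Phi}$ anti-symmetric when $n=m$. The helper names, the random seed, and the even size $m=4$ (chosen so that the anti-symmetric part is generically invertible) are assumptions made for illustration only.

```python
import numpy as np

def sym(A):
    return 0.5 * (A + A.T)

def msign(M):
    # Polar factor U V^T from the SVD; assumes M has full rank.
    U, _, Vh = np.linalg.svd(M, full_matrices=False)
    return U @ Vh

rng = np.random.default_rng(0)
m = 4  # even size, so the anti-symmetric part is generically invertible
W = np.linalg.qr(rng.standard_normal((m, m)))[0]  # square orthogonal matrix
G = rng.standard_normal((m, m))

X = -sym(W.T @ G)           # closed-form solution for n = m
Phi = msign(G + W @ X)

# W^T Phi should be anti-symmetric, i.e. its symmetric part should vanish.
tangent_error = np.abs(sym(W.T @ Phi)).max()
print('tangent error:', tangent_error)
```

For $n>m$ the same choice of $\boldsymbol{X}$ is in general only an initial guess, which is what the rest of the article addresses.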
Equation Transformation#
In short, our current task is to solve the equation system:
$$ \boldsymbol{W}^{\top}\msign(\boldsymbol{G} + \boldsymbol{W}\boldsymbol{X})+\msign(\boldsymbol{G} + \boldsymbol{W}\boldsymbol{X})^{\top}\boldsymbol{W} = \boldsymbol{0} \tag{3} $$When $n=m$, $\boldsymbol{W}^{\top}$ can be absorbed directly into $\msign$, which simplifies the solution; for $n > m$, however, such absorption is not possible, and this is where the difficulty lies. I lean towards believing that no simple explicit solution exists when $n > m$, so we will seek a numerical algorithm.
According to the definition $\msign(\boldsymbol{M})=\boldsymbol{M}(\boldsymbol{M}^{\top}\boldsymbol{M})^{-1/2}$, we can write:
$$ \boldsymbol{W}^{\top}\msign(\boldsymbol{G} + \boldsymbol{W}\boldsymbol{X}) = \boldsymbol{W}^{\top}(\boldsymbol{G} + \boldsymbol{W}\boldsymbol{X})\boldsymbol{Q}^{-1} = (\boldsymbol{W}^{\top}\boldsymbol{G} + \boldsymbol{X})\boldsymbol{Q}^{-1} $$where $\boldsymbol{Q} = ((\boldsymbol{G} + \boldsymbol{W}\boldsymbol{X})^{\top}(\boldsymbol{G} + \boldsymbol{W}\boldsymbol{X}))^{1/2}$. Under this new notation, the equation system becomes:
$$ (\boldsymbol{W}^{\top}\boldsymbol{G} + \boldsymbol{X})\boldsymbol{Q}^{-1} + \boldsymbol{Q}^{-1}(\boldsymbol{G}^{\top}\boldsymbol{W} + \boldsymbol{X}) = \boldsymbol{0} $$Multiplying by $\boldsymbol{Q}$ from the left and right simultaneously, we get:
$$ \boldsymbol{Q}(\boldsymbol{W}^{\top}\boldsymbol{G} + \boldsymbol{X}) + (\boldsymbol{G}^{\top}\boldsymbol{W} + \boldsymbol{X})\boldsymbol{Q} = \boldsymbol{0} \tag{6} $$Furthermore, $\boldsymbol{Q}$ also satisfies:
$$ \boldsymbol{Q} = (\boldsymbol{G} + \boldsymbol{W}\boldsymbol{X})^{\top}\msign(\boldsymbol{G} + \boldsymbol{W}\boldsymbol{X}) \tag{7} $$
Iterative Solution#
My idea now is to start with an initial value for $\boldsymbol{X}$, substitute it into equation (7) to get $\boldsymbol{Q}$, then substitute $\boldsymbol{Q}$ into equation (6) to solve for a new $\boldsymbol{X}$, and repeat this iteration until convergence. With $\msign$ known, equation (7) can be computed explicitly, so the only difficulty is solving equation (6).
We can rearrange equation (6):
$$ \boldsymbol{Q}\boldsymbol{X} + \boldsymbol{X}\boldsymbol{Q} = -2[\boldsymbol{Q}\boldsymbol{W}^{\top}\boldsymbol{G}]_{\text{sym}} \tag{8} $$Given $\boldsymbol{Q}$, this is a linear system in $\boldsymbol{X}$ known as a “continuous Lyapunov equation,” which can also be seen as a special case of the “Sylvester equation.” If we only compute on CPU, SciPy already ships a solver for this equation, scipy.linalg.solve_continuous_lyapunov, which can be called directly.
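The alternation described above (compute $\boldsymbol{Q}$ from the current $\boldsymbol{X}$, then solve the Lyapunov equation for a new $\boldsymbol{X}$) can be sketched on CPU with SciPy as follows. This is a minimal illustration rather than a tuned implementation: the function name `stiefel_direction`, the step count, the random test data, and the SVD-based `msign` are assumptions, and the sketch presumes $\boldsymbol{G}+\boldsymbol{W}\boldsymbol{X}$ stays full rank.

```python
import numpy as np
from scipy.linalg import solve_continuous_lyapunov

def sym(A):
    return 0.5 * (A + A.T)

def msign(M):
    # Polar factor via SVD; assumes M has full column rank.
    U, _, Vh = np.linalg.svd(M, full_matrices=False)
    return U @ Vh

def stiefel_direction(G, W, steps=10):
    """Hypothetical helper: alternate between updating Q and solving for X."""
    X = -sym(W.T @ G)                          # square-case solution as initial value
    for _ in range(steps):
        Z = G + W @ X
        Q = Z.T @ msign(Z)                     # Q = (Z^T Z)^{1/2}
        rhs = -2.0 * sym(Q @ W.T @ G)
        X = solve_continuous_lyapunov(Q, rhs)  # solves Q X + X Q^T = rhs
    return msign(G + W @ X), X, Q, rhs

rng = np.random.default_rng(0)
n, m = 8, 3
W = np.linalg.qr(rng.standard_normal((n, m)))[0]  # orthonormal columns
G = rng.standard_normal((n, m))

Phi, X, Q, rhs = stiefel_direction(G, W)
lyapunov_residual = np.abs(Q @ X + X @ Q - rhs).max()
print('inner product:', (Phi * G).sum())
print('tangent error:', np.abs(sym(W.T @ Phi)).max())
```

Note that `solve_continuous_lyapunov(a, q)` solves $\boldsymbol{A}\boldsymbol{X}+\boldsymbol{X}\boldsymbol{A}^{H}=\boldsymbol{Q}$ in SciPy's notation, so the symmetric matrix called $\boldsymbol{Q}$ in this article is passed as the first argument and the right-hand side as the second.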
Regarding the choice of initial value, we can consider the solution for the square matrix case, $-[\boldsymbol{W}^{\top}\boldsymbol{G}]_{\text{sym}}$, which is a natural transition from square to non-square matrices. We can also observe the reasonableness of the initial value $-[\boldsymbol{W}^{\top}\boldsymbol{G}]_{\text{sym}}$ from another equivalent form of equation (8):
$$ \boldsymbol{Q}(\boldsymbol{X} + [\boldsymbol{W}^{\top}\boldsymbol{G}]_{\text{sym}}) + (\boldsymbol{X} + [\boldsymbol{W}^{\top}\boldsymbol{G}]_{\text{sym}})\boldsymbol{Q} =[\boldsymbol{W}^{\top}\boldsymbol{G}]_{\text{skew}}\boldsymbol{Q} -\boldsymbol{Q}[\boldsymbol{W}^{\top}\boldsymbol{G}]_{\text{skew}} $$Therefore, the accuracy of $-[\boldsymbol{W}^{\top}\boldsymbol{G}]_{\text{sym}}$ depends on how commutative the multiplication of $[\boldsymbol{W}^{\top}\boldsymbol{G}]_{\text{skew}}$ and $\boldsymbol{Q}$ is; the closer they are to commuting matrices, the more accurate $-[\boldsymbol{W}^{\top}\boldsymbol{G}]_{\text{sym}}$ will be. However, subsequent empirical results show that our iterative algorithm is not particularly sensitive to the initial value; even using a zero matrix as the initial value does not pose a big problem.
Do It Yourself#
We just mentioned that SciPy comes with a function for solving the Lyapunov equation, so it can be called directly without worrying about the solving process. However, this is limited to SciPy on CPU. I checked, and neither PyTorch nor JAX has a similar function, so for GPU computation, one has to “do it oneself.”
There are two approaches to programmatically solve equation (8). One is to follow the idea from What can the matrix sign function mcsgn compute? and use $\mcsgn$ (not $\msign$) to solve it:
$$ \boldsymbol{X} = \mcsgn\left(\begin{bmatrix}-\boldsymbol{Q} & -[\boldsymbol{Q}\boldsymbol{W}^{\top}\boldsymbol{G}]_{\text{sym}} \\ \boldsymbol{0} & \boldsymbol{Q}\end{bmatrix}\right)_{[:m,m:]} $$The second is based on SVD. We have already used this method when calculating the gradient of $\msign$ in The Derivative of msign; here, we will reintroduce it in conjunction with equation (8). According to the definition of $\boldsymbol{Q}$, it is positive definite and symmetric. Thus, it can be decomposed by eigenvalue decomposition as $\boldsymbol{V}\boldsymbol{\Sigma}\boldsymbol{V}^{\top}$, where $\boldsymbol{V}$ is an orthogonal matrix and $\boldsymbol{\Sigma}=\mathop{\text{diag}}(\sigma_1,\cdots,\sigma_m)$ is a diagonal matrix. Substituting this into equation (8), we can rearrange to get:
$$ \boldsymbol{\Sigma}(\boldsymbol{V}^{\top}\boldsymbol{X}\boldsymbol{V}) + (\boldsymbol{V}^{\top}\boldsymbol{X}\boldsymbol{V})\boldsymbol{\Sigma} = -2\boldsymbol{V}^{\top}[\boldsymbol{Q}\boldsymbol{W}^{\top}\boldsymbol{G}]_{\text{sym}}\boldsymbol{V} $$The left side can be expressed as $(\boldsymbol{V}^{\top}\boldsymbol{X}\boldsymbol{V})\otimes \boldsymbol{S}$, where $\otimes$ is the Hadamard product, and $\boldsymbol{S}_{i,j} = \sigma_i + \sigma_j$. From this, we can solve for:
$$ \boldsymbol{X} = -2\boldsymbol{V}((\boldsymbol{V}^{\top}[\boldsymbol{Q}\boldsymbol{W}^{\top}\boldsymbol{G}]_{\text{sym}}\boldsymbol{V})\oslash \boldsymbol{S})\boldsymbol{V}^{\top} $$where $\oslash$ is the Hadamard division. The interesting point here is that performing eigenvalue decomposition on $\boldsymbol{Q}$ is essentially equivalent to performing SVD on $\boldsymbol{G} + \boldsymbol{W}\boldsymbol{X}$. Furthermore, SVD on $\boldsymbol{G} + \boldsymbol{W}\boldsymbol{X}$ can also be used to calculate $\msign(\boldsymbol{G} + \boldsymbol{W}\boldsymbol{X})$. Therefore, a single SVD operation can yield both $\msign$ and the solution to equation (8).
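The Hadamard-division formula above can be sketched as a small standalone solver for $\boldsymbol{Q}\boldsymbol{X}+\boldsymbol{X}\boldsymbol{Q}=\boldsymbol{C}$ with symmetric positive definite $\boldsymbol{Q}$. The function name `lyapunov_eigh` and the random test matrices are illustrative assumptions; in the article's setting one would take $\boldsymbol{C}=-2[\boldsymbol{Q}\boldsymbol{W}^{\top}\boldsymbol{G}]_{\text{sym}}$.

```python
import numpy as np

def lyapunov_eigh(Q, C):
    """Solve Q X + X Q = C for a symmetric positive definite Q."""
    sigma, V = np.linalg.eigh(Q)             # Q = V diag(sigma) V^T
    S = sigma[:, None] + sigma[None, :]      # S[i, j] = sigma_i + sigma_j
    return V @ ((V.T @ C @ V) / S) @ V.T     # Hadamard division by S

rng = np.random.default_rng(0)
A = rng.standard_normal((6, 4))
Q = A.T @ A + np.eye(4)                      # symmetric positive definite
C = rng.standard_normal((4, 4))
C = C + C.T                                  # symmetric right-hand side

X = lyapunov_eigh(Q, C)
residual = np.abs(Q @ X + X @ Q - C).max()
print('residual:', residual)
```

Because it uses only an eigendecomposition and matrix products, the same few lines should carry over to GPU array libraries that provide a symmetric eigensolver, though that has not been checked here.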
Both approaches have their own characteristics. Approach one requires first computing $\msign$ of an $n\times m$ matrix and then $\mcsgn$ of a $2m\times 2m$ matrix. Although both can be computed efficiently with Newton-Schulz iteration, the cost is not negligible. Additionally, we must choose coefficients that ensure convergence and high precision (refer to the results in Newton-Schulz Iteration for the msign operator (Part 2)); otherwise neither the $\mcsgn$ nor the $\msign$ computation will converge, let alone $\boldsymbol{X}$.
Approach two requires SVD. Although SVD has higher complexity and often requires forced FP32 precision, for this problem, each iteration only needs one SVD to compute both $\msign$ and $\boldsymbol{X}$ simultaneously, so the overall efficiency won’t be too bad. If we don’t need too many matrix parameters with orthogonal constraints, SVD might be the simplest choice.
Related Results#
Before this article, @leloy also proposed two heuristic solution methods for the original objective (1) in his blog post Heuristic Solutions for Steepest Descent on the Stiefel Manifold. Here, “heuristic” means that in most cases, it can yield a decent solution, but it cannot guarantee optimality. Let’s learn about them together.
The first method can be called a purely geometric approach. First, let’s define the projection operation:
$$ \mathcal{P}_{\boldsymbol{W}}(\boldsymbol{M}) = \boldsymbol{M} - \boldsymbol{W}[\boldsymbol{W}^{\top}\boldsymbol{M}]_{\text{sym}} $$It can be verified that $\boldsymbol{W}^{\top}\mathcal{P}_{\boldsymbol{W}}(\boldsymbol{M})$ is always an anti-symmetric matrix, meaning $\mathcal{P}_{\boldsymbol{W}}(\boldsymbol{M})$ is always in the tangent space. Therefore, we consider it a projection operation of any matrix $\boldsymbol{M}$ onto the tangent space of $\boldsymbol{W}^{\top}\boldsymbol{W}=\boldsymbol{I}$.
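The claim that $\mathcal{P}_{\boldsymbol{W}}(\boldsymbol{M})$ always lies in the tangent space can be checked numerically. The sketch below uses random test matrices (an assumption for illustration) and also confirms that applying the projection twice changes nothing, as expected of a projection.

```python
import numpy as np

def sym(A):
    return 0.5 * (A + A.T)

def proj(M, W):
    # P_W(M) = M - W [W^T M]_sym
    return M - W @ sym(W.T @ M)

rng = np.random.default_rng(0)
n, m = 7, 3
W = np.linalg.qr(rng.standard_normal((n, m)))[0]  # W^T W = I
M = rng.standard_normal((n, m))

P = proj(M, W)
tangent_error = np.abs(sym(W.T @ P)).max()        # symmetric part of W^T P
idempotent_error = np.abs(proj(P, W) - P).max()   # P_W(P_W(M)) - P_W(M)
print('tangent error:', tangent_error)
print('idempotent error:', idempotent_error)
```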
Starting from the gradient $\boldsymbol{G}$, $\mathcal{P}_{\boldsymbol{W}}(\boldsymbol{G})$ is certainly in the tangent space. However, we know that Muon’s update quantity must be an orthogonal matrix (when full rank), whereas $\mathcal{P}_{\boldsymbol{W}}(\boldsymbol{G})$ is not necessarily orthogonal. Therefore, we can use $\msign$ to find the orthogonal matrix closest to it, i.e., $\msign(\mathcal{P}_{\boldsymbol{W}}(\boldsymbol{G}))$. After applying $\msign$, however, the result may no longer be in the tangent space. We can then project it back onto the tangent space, find the nearest orthogonal matrix again, and keep iterating:
$$ \boldsymbol{\Phi} = (\msign\circ\mathcal{P}_{\boldsymbol{W}}\circ\cdots\circ\msign\circ\mathcal{P}_{\boldsymbol{W}})(\boldsymbol{G}) $$This is @leloy’s first approach: alternating projections onto the tangent space and the orthogonal space until convergence, which is quite intuitive. Moreover, in relatively random cases it is very close to the optimal solution, even accurate to 4 decimal places, so much so that I initially thought it was the exact solution. However, after searching further, I found cases where it deviates noticeably from the optimal solution, confirming that the agreement was merely a coincidence and that it is not the optimal solution.
The second method can be called line search. Specifically, when $n > m$, we can consider extending $\boldsymbol{W}$ into a standard $n\times n$ orthogonal matrix $[\boldsymbol{W},\overline{\boldsymbol{W}}]$, and then decompose the desired $\boldsymbol{\Phi}$ into two parts: $\boldsymbol{W}^{\top}\boldsymbol{\Phi}$ and $\overline{\boldsymbol{W}}{}^{\top}\boldsymbol{\Phi}$. Then @leloy performed a greedy approximation: first solving for the optimal $\boldsymbol{W}^{\top}\boldsymbol{\Phi}$, then for the optimal $\overline{\boldsymbol{W}}{}^{\top}\boldsymbol{\Phi}$, and finally introducing a line search between the two to improve accuracy.
This set of operations indeed yields a solution with a good approximation, and it is guaranteed to be within the tangent space and satisfy orthogonality. The solving process requires computing the spectral norm, $\msign$, and Cholesky decomposition; for details, please refer to the author’s article. Furthermore, when $m=2$, it is theoretically possible to find the optimal solution because a $2\times 2$ anti-symmetric matrix has only one free parameter, and the line search provides exactly one degree of freedom.
Let’s Test It#
Below, we empirically test the above methods in NumPy. The main purpose is to verify the correctness of the methods themselves, so we implement $\msign$ and $\mcsgn$ directly via singular value decomposition and eigenvalue decomposition.
import numpy as np
import scipy as sp
def mcsgn(x):
    """Compute mcsgn exactly via eigenvalue decomposition."""
    s, v = np.linalg.eig(x)
    return v @ np.diag(np.sign(s)) @ np.linalg.inv(v)


def msign(g):
    """Compute msign exactly via singular value decomposition."""
    u, s, vh = np.linalg.svd(g, full_matrices=False)
    return u @ np.diag(np.sign(s)) @ vh


def sym(x):
    """Symmetrize."""
    return (x + x.T) * 0.5


def skew(x):
    """Anti-symmetrize."""
    return (x - x.T) * 0.5


def proj(g, w):
    """Project onto the tangent space of the orthogonality constraint."""
    return g - w @ sym(w.T @ g)

def jianlin_by_mcsgn(g, w, steps=20):
    """Run this article's iteration, solving for x via mcsgn."""
    n, m = g.shape
    x = -sym(w.T @ g)
    for i in range(1, steps + 1):
        phi = msign(z := g + w @ x)
        print('step:', i, ', inner product:', (phi * g).sum(), ', tangent error:', np.abs(sym(w.T @ phi)).mean())
        if i == steps:
            return phi
        q = z.T @ phi
        x = mcsgn(np.block([[-q, -sym(q @ w.T @ g)], [np.zeros_like(q), q]]))[:m, m:]
        # x = -2 * sp.linalg.solve_continuous_lyapunov(q, sym(q @ w.T @ g))


def jianlin_by_svd(g, w, steps=20):
    """Run this article's iteration, solving for x via SVD."""
    x = -sym(w.T @ g)
    for i in range(1, steps + 1):
        u, s, vh = np.linalg.svd(z := g + w @ x, full_matrices=False)
        phi = (u * np.sign(s)) @ vh
        print('step:', i, ', inner product:', (phi * g).sum(), ', tangent error:', np.abs(sym(w.T @ phi)).mean())
        if i == steps:
            return phi
        x = -2 * vh.T @ (vh @ sym(z.T @ phi @ w.T @ g) @ vh.T / (s + s[:, None])) @ vh

def leloy_v1(g, w, steps=20):
    """Alternately project onto the tangent space and the orthogonal space."""
    phi = g
    for i in range(1, steps + 1):
        phi = msign(proj(phi, w))
        print('step:', i, ', inner product:', (phi * g).sum(), ', tangent error:', np.abs(sym(w.T @ phi)).mean())
    return phi

def leloy_v2(g, w, steps=20):
    """Greedy part-by-part solution + line search (form simplified by this article's author)."""
    n, m = g.shape
    taus = np.linspace(0, 1, steps + 2)[1:-1]
    p_max, tau_opt, phi_opt = 0, 0, None
    for tau in taus:
        b = (b := skew(w.T @ g)) * tau / max(np.linalg.norm(b, ord=2), 1e-8)
        # NumPy returns the lower-triangular factor: r @ r.T = I - b.T @ b
        r = np.linalg.cholesky(np.eye(m) - b.T @ b)
        # right-multiply by r.T so that c.T @ c = r @ r.T and phi.T @ phi = I
        c = msign((np.eye(n) - w @ w.T) @ g @ r) @ r.T
        phi = w @ b + c
        print('tau:', tau, ', inner product:', p := (phi * g).sum())
        if p > p_max:
            p_max, tau_opt, phi_opt = p, tau, phi
    print('best inner product:', p_max, ', tau:', tau_opt)
    return phi_opt

w = np.array([[ 0.69453734, -0.26590866, -0.44721806, 0.2753041 ],
[-0.11738148, -0.5588003 , -0.17580748, 0.3218624 ],
[-0.4515288 , -0.23489913, -0.26683152, -0.25739142],
[ 0.02392521, 0.02664689, 0.48423648, 0.6193399 ],
[ 0.45194831, -0.25206333, 0.27654836, -0.60242337],
[ 0.21197332, -0.09174792, 0.24521762, -0.08484317],
[-0.15496767, -0.26446804, -0.34942415, -0.01877318],
[-0.16181251, -0.6474956 , 0.45243263, -0.01776086]])
g = np.array([[-17.85745 , -10.758921 , -2.9583392 , 6.245008 ],
[-28.883093 , 19.772121 , 8.086545 , -21.564013 ],
[ -1.6274693 , -14.96859 , 3.4465332 , 3.1070817 ],
[ -7.8890743 , 1.5304767 , -8.949573 , 9.579629 ],
[ 2.246596 , 14.46572 , 12.8451 , -2.7370298 ],
[ -0.9496974 , 6.9879804 , 2.849277 , 1.1148484 ],
[ -8.115278 , -18.054405 , -0.19287404, 7.0389237 ],
[-15.062008 , -15.02901 , 2.9083247 , 21.706533 ]])
phi1 = jianlin_by_mcsgn(g, w, steps=100)
phi2 = jianlin_by_svd(g, w, steps=100)
phi3 = leloy_v1(g, w, steps=100)
phi4 = leloy_v2(g, w, steps=100)
assert np.allclose(phi1, phi2)
w = np.linalg.qr(np.random.randn(100, 50))[0]
g = np.random.randn(100, 50)
phi1 = jianlin_by_mcsgn(g, w, steps=10)
phi2 = jianlin_by_svd(g, w, steps=10)
phi3 = leloy_v1(g, w, steps=10)
phi4 = leloy_v2(g, w, steps=10)
assert np.allclose(phi1, phi2)
For the first set of $\boldsymbol{W},\boldsymbol{G}$ given in the code, my method yields an optimal $\tr(\boldsymbol{G}^{\top} \boldsymbol{\Phi})$ of approximately $90$, and the results from $\mcsgn$ and SVD are identical. In contrast, @leloy’s first method yields approximately $70$, and his second method yields approximately $80$, both showing a certain deviation from the optimal solution.
However, the first set of $\boldsymbol{W},\boldsymbol{G}$ was specifically chosen as an extreme example to highlight the differences among the three methods. If we use relatively random values, then my solution and @leloy’s first solution will be very close, and the number of iterations can be much smaller (5-10 steps). In such cases, @leloy’s second solution will show a larger deviation from the optimal solution. Readers can construct their own examples to test this.
Further Thoughts#
Regarding the solution to the original problem (1), this concludes the discussion for now. Next, let’s discuss a few detailed issues that might cause confusion.
First, for convenience of description, the iterative process presented earlier carries an implicit assumption: that $\boldsymbol{G} + \boldsymbol{W}\boldsymbol{X}$ is always full rank (rank $m$). Otherwise, the matrix $\boldsymbol{S}$ would have zero entries, and the Hadamard division $\oslash\boldsymbol{S}$ would be ill-defined. However, this difficulty is not fundamental, because equation (3) must have a solution. Therefore, wherever the denominator is zero, the numerator must also be zero, and we only need to replace the zero entries of $\boldsymbol{S}$ with a small positive number to obtain the correct result.
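A minimal sketch of that safeguard is shown below, assuming a hypothetical helper `safe_hadamard_divide` and an illustrative threshold `eps`; neither appears in the original code. The toy example has one zero singular value, so exactly one entry of $\boldsymbol{S}$ vanishes, and the matching numerator entry is zero as the argument above requires.

```python
import numpy as np

def safe_hadamard_divide(numer, S, eps=1e-12):
    """Elementwise numer / S, with (near-)zero entries of S replaced by eps."""
    return numer / np.where(np.abs(S) > eps, S, eps)

sigma = np.array([1.0, 0.0])               # one zero singular value
S = sigma[:, None] + sigma[None, :]        # [[2, 1], [1, 0]]
numer = np.array([[2.0, 1.0],
                  [1.0, 0.0]])             # zero wherever S is zero

Y = safe_hadamard_divide(numer, S)
print(Y)
```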
From a numerical computation perspective, we rarely encounter singular values that are exactly zero, so there’s no need to worry too much about this issue; assuming $\boldsymbol{G} + \boldsymbol{W}\boldsymbol{X}$ is full rank by default is fine. Under this default assumption, the retraction operation becomes very simple, because:
$$ (\boldsymbol{W} - \eta\boldsymbol{\Phi})^{\top}(\boldsymbol{W} - \eta\boldsymbol{\Phi}) = \boldsymbol{W}^{\top} \boldsymbol{W} - \eta(\boldsymbol{W}^{\top} \boldsymbol{\Phi} + \boldsymbol{\Phi}^{\top}\boldsymbol{W}) + \eta^2 \boldsymbol{\Phi}^{\top}\boldsymbol{\Phi} $$According to the definition of the Stiefel manifold, the first term on the right-hand side is $\boldsymbol{I}$. According to the tangent space condition, the second term is $\boldsymbol{0}$. Finally, when full rank, $\msign$ produces a matrix that is also on the Stiefel manifold, so the third term is $\eta^2 \boldsymbol{I}$. The total result is $(1+\eta^2)\boldsymbol{I}$, and retraction can be achieved simply by dividing by $\sqrt{1+\eta^2}$:
$$ \boldsymbol{W}\quad\leftarrow\quad\frac{\boldsymbol{W} - \eta\boldsymbol{\Phi}}{\sqrt{1+\eta^2}} $$
Having read this far, you may have noticed a deeper issue: for both the relatively simple orthogonal manifold and the more complex Stiefel manifold, what precision should we use for the calculations? “Orthogonality” is an exact quantitative constraint: $\boldsymbol{W}^{\top}\boldsymbol{W}=\boldsymbol{I}$ comprises $m(m+1)/2$ equality constraints. It is foreseeable that iterating with the above formulas in low precision will, over time, lead to significant deviation from orthogonality, not to mention the errors incurred while solving for $\boldsymbol{\Phi}$.
Therefore, I believe that unless we periodically apply an orthogonalization operation (i.e., $\boldsymbol{W}\leftarrow\msign(\boldsymbol{W})$) to the parameters to pull them back onto the orthogonal manifold, the computational precision for the solving process should at least be FP32. Considering that there usually aren’t many parameters requiring orthogonal constraints, this generally doesn’t incur a significant cost.
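The retraction formula discussed above can be checked with a small NumPy sketch. To keep it self-contained, the direction `Phi` here is a hand-built tangent direction with orthonormal columns (taken from the orthogonal complement of `W`, so that $\boldsymbol{W}^{\top}\boldsymbol{\Phi}=\boldsymbol{0}$), not the steepest-descent solution; the sizes, seed, and step size are likewise illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
n, m, eta = 10, 3, 0.1

# Orthonormal basis of R^n: the first m columns play the role of W,
# the next m columns give a direction with W^T Phi = 0 and Phi^T Phi = I.
basis = np.linalg.qr(rng.standard_normal((n, n)))[0]
W = basis[:, :m]
Phi = basis[:, m:2 * m]

# Retraction: (W - eta * Phi)^T (W - eta * Phi) = (1 + eta^2) I before scaling.
W_new = (W - eta * Phi) / np.sqrt(1 + eta**2)
ortho_error = np.abs(W_new.T @ W_new - np.eye(m)).max()
print('orthogonality error:', ortho_error)
```

In float64 the error stays near machine precision for a single step; the concern raised above is about how such errors accumulate over many low-precision updates.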
Summary#
This article extends the “Muon + Orthogonal Manifold” from the previous article to the more general “Muon + Stiefel Manifold,” primarily discovering an iterative algorithm for solving the corresponding update quantity.
@online{kexuefm-11221,
title={Steepest Descent on Manifolds: 3. Muon + Stiefel},
author={苏剑林},
year={2025},
month={08},
url={\url{https://kexue.fm/archives/11221}},
}