This is a gemini-2.5-flash translation of a Chinese article.
It has NOT been vetted for errors. You should have the original article open in a parallel tab at all times.
By Su Jianlin | 2025-05-11 | 21,697 Readers
In previous articles such as “Appreciation of Muon Optimizer: An Essential Leap from Vectors to Matrices” and “Muon Sequel: Why We Chose to Try Muon?”, we introduced a highly promising emerging optimizer, “Muon”, which has the potential to replace Adam. As related research continues to deepen, the Muon optimizer is receiving increasing attention.
Readers familiar with Muon know that its core operation is the $\msign$ operator, and finding more efficient computation methods for it is an ongoing goal of the academic community. This article will summarize its latest progress.
Preamble#
The definition of $\msign$ is closely related to SVD. Given a matrix $\boldsymbol{M}\in\mathbb{R}^{n\times m}$, then
$$ \boldsymbol{U},\boldsymbol{\Sigma},\boldsymbol{V}^{\top} = \text{SVD}(\boldsymbol{M}) \quad\Rightarrow\quad \msign(\boldsymbol{M}) = \boldsymbol{U}_{[:,:r]}\boldsymbol{V}_{[:,:r]}^{\top} $$where $\boldsymbol{U}\in\mathbb{R}^{n\times n},\boldsymbol{\Sigma}\in\mathbb{R}^{n\times m},\boldsymbol{V}\in\mathbb{R}^{m\times m}$, and $r$ is the rank of $\boldsymbol{M}$. Simply put, $\msign$ is a new matrix obtained by changing all non-zero singular values of the matrix to 1. Based on SVD, we can also prove that
$$ \text{msign}(\boldsymbol{M}) = (\boldsymbol{M}\boldsymbol{M}^{\top})^{-1/2}\boldsymbol{M}= \boldsymbol{M}(\boldsymbol{M}^{\top}\boldsymbol{M})^{-1/2} $$Here, $^{-1/2}$ denotes the $-1/2$ power of a matrix. This form is very similar to the scalar $\mathop{\text{sign}}(x) = x / \sqrt{x^2}$, which is why the author used the name $\msign$. However, it’s important to note that this is not entirely the same as the Wikipedia “Matrix Sign”; the Wikipedia concept applies only to square matrices, but they are consistent when $\boldsymbol{M}$ is a symmetric matrix.
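As a quick numerical illustration of the two characterizations above, the following NumPy sketch computes $\msign$ once from the SVD and once from $(\boldsymbol{M}\boldsymbol{M}^{\top})^{-1/2}\boldsymbol{M}$, then compares the results. The helper names `msign_svd` and `msign_invsqrt`, the rank tolerance, and the random test matrix are assumptions made for this example; the second helper additionally assumes $\boldsymbol{M}$ has full row rank so that $\boldsymbol{M}\boldsymbol{M}^{\top}$ is invertible.

```python
import numpy as np

def msign_svd(M, tol=1e-10):
    """Exact msign via SVD: every non-zero singular value is set to 1."""
    U, s, Vt = np.linalg.svd(M, full_matrices=False)
    r = int((s > tol * s.max()).sum())  # numerical rank (tolerance is an assumption)
    return U[:, :r] @ Vt[:r, :]

def msign_invsqrt(M):
    """msign via (M M^T)^{-1/2} M; assumes M has full row rank."""
    w, Q = np.linalg.eigh(M @ M.T)      # M M^T is symmetric positive definite here
    return (Q * w**-0.5) @ Q.T @ M      # Q diag(w^{-1/2}) Q^T M

rng = np.random.default_rng(0)
M = rng.standard_normal((4, 6))         # wide matrix, full row rank with probability 1
A = msign_svd(M)
B = msign_invsqrt(M)

print("max |A - B|:", np.abs(A - B).max())
print("singular values of msign(M):", np.linalg.svd(A, compute_uv=False))
```

Both routes should agree up to floating-point error, and the result should have all singular values equal to 1.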
When $m=n=r$, $\text{msign}(\boldsymbol{M})$ also has the meaning of “optimal orthogonal approximation”:
$$ \text{msign}(\boldsymbol{M}) = \mathop{\text{argmin}}_{\boldsymbol{O}^{\top}\boldsymbol{O} = \boldsymbol{I}}\Vert \boldsymbol{M} - \boldsymbol{O}\Vert_F^2 $$The proof can be found in “Appreciation of Muon Optimizer: An Essential Leap from Vectors to Matrices”. Because of this property, $\msign$ is also known as “symmetric orthogonalization”, a name first introduced in “On the Nonorthogonality Problem” (refer to the Wikipedia entry on “Orthogonalization”).
Finally, in “Higher-order muP: A Simpler but More Sophisticated Spectral Conditioning Scaling”, the author also regarded $\msign$ as the limit version of “singular value clipping”.
Iterative Computation#
$\msign$ is defined through the SVD, so it can of course be computed exactly with an SVD. However, an exact SVD is expensive, so in practice the “Newton-Schulz iteration” is usually used to approximate it.
Newton-Schulz iteration is a commonly used iterative algorithm for computing matrix functions. For $\msign$, the iteration takes the form
$$ \boldsymbol{X}_0 = \frac{\boldsymbol{M}}{\Vert\boldsymbol{M}\Vert_F},\qquad \boldsymbol{X}_{t+1} = a\boldsymbol{X}_t + b\boldsymbol{X}_t(\boldsymbol{X}_t^{\top}\boldsymbol{X}_t) + c\boldsymbol{X}_t(\boldsymbol{X}_t^{\top}\boldsymbol{X}_t)^2+\cdots $$where $\Vert\boldsymbol{M}\Vert_F$ is the Frobenius norm of $\boldsymbol{M}$, which is the square root of the sum of squares of all its elements. $(a,b,c,\cdots)$ are undetermined coefficients. In practical computation, we need to truncate to a finite number of terms, commonly 2 or 3 terms, i.e., one of the following:
$$ \boldsymbol{X}_{t+1} = a\boldsymbol{X}_t + b\boldsymbol{X}_t(\boldsymbol{X}_t^{\top}\boldsymbol{X}_t) $$$$ \boldsymbol{X}_{t+1} = a\boldsymbol{X}_t + b\boldsymbol{X}_t(\boldsymbol{X}_t^{\top}\boldsymbol{X}_t) + c\boldsymbol{X}_t(\boldsymbol{X}_t^{\top}\boldsymbol{X}_t)^2 $$Finally, the $\boldsymbol{X}_T$ after $T$ iterations is returned as an approximation of $\msign(\boldsymbol{M})$. Thus, the coefficients $(a,b,c)$ and the number of iterations $T$ constitute all the hyperparameters of the Newton-Schulz iteration. The reference choice given by Muon author KellerJordan is
$$ (a,b,c)=(3.4445, -4.7750, 2.0315),\qquad T = 5 $$Our next topic is to understand it and then try to improve it.
Reference Implementation#
Here is a minimalistic reference implementation:
```python
def msign(x, steps=5, eps=1e-20):
    a, b, c, y = 3.4445, -4.7750, 2.0315, x.astype('bfloat16')
    # work in the wide orientation so that y @ y.mT is the smaller Gram matrix
    y = y.mT if x.shape[-2] > x.shape[-1] else y
    # Frobenius normalization: all singular values now lie in [0, 1]
    y /= ((y**2).sum(axis=(-2, -1), keepdims=True) + eps)**0.5
    for _ in range(steps):
        y = a * y + (b * (y2 := y @ y.mT) + c * y2 @ y2) @ y
    return y.mT if x.shape[-2] > x.shape[-1] else y
```
This implementation already supports batching (`msign` is applied only to the last two dimensions) and runs in JAX as written. Changing `x.astype('bfloat16')` to `x.to(torch.bfloat16)` lets it run in Torch, and changing `x.astype('bfloat16')` to plain `x` lets it run in NumPy.
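To see what the iteration actually returns, here is a float64 NumPy port of the same idea, offered as a sketch rather than as the official implementation. The name `msign_ns` and the test shapes are assumptions for illustration, and `np.swapaxes` is used in place of `.mT` so that the example does not depend on that attribute being available.

```python
import numpy as np

def msign_ns(x, steps=5, eps=1e-20):
    """Float64 NumPy port of the quintic Newton-Schulz iteration (illustrative)."""
    a, b, c = 3.4445, -4.7750, 2.0315
    transpose = x.shape[-2] > x.shape[-1]
    y = np.swapaxes(x, -2, -1) if transpose else x
    # Frobenius normalization over the last two axes
    y = y / np.sqrt((y**2).sum(axis=(-2, -1), keepdims=True) + eps)
    for _ in range(steps):
        y2 = y @ np.swapaxes(y, -2, -1)          # small Gram matrix
        y = a * y + (b * y2 + c * y2 @ y2) @ y
    return np.swapaxes(y, -2, -1) if transpose else y

rng = np.random.default_rng(0)
m = rng.standard_normal((3, 16, 8))              # batch of three tall matrices
out = msign_ns(m)
s = np.linalg.svd(out, compute_uv=False)
print(out.shape, "singular values in", (s.min(), s.max()))
```

With these coefficients the singular values of the output land in a band around 1 rather than exactly at 1, which is the behaviour the following sections analyze.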
Principle Analysis#
To understand the principle of Newton-Schulz iteration, we will analyze its steps one by one. First, for $\boldsymbol{X}_0 = \boldsymbol{M}/\Vert\boldsymbol{M}\Vert_F$, we substitute the SVD of $\boldsymbol{M}$:
$$ \boldsymbol{X}_0 = \frac{\boldsymbol{M}}{\Vert\boldsymbol{M}\Vert_F} = \boldsymbol{U}_{[:,:r]}\left(\frac{\boldsymbol{\Sigma}_{[:r,:r]}}{\Vert\boldsymbol{M}\Vert_F}\right)\boldsymbol{V}_{[:,:r]}^{\top} = \boldsymbol{U}_{[:,:r]}\underbrace{\left(\frac{\boldsymbol{\Sigma}_{[:r,:r]}}{\Vert\boldsymbol{\Sigma}_{[:r,:r]}\Vert_F}\right)}_{\boldsymbol{S}_0}\boldsymbol{V}_{[:,:r]}^{\top} $$The last equality holds because the square of the Frobenius norm is equal to both the sum of squares of all elements and the sum of squares of all singular values. The final result shows that $\boldsymbol{S}_0$ is a diagonal matrix whose elements are all within $[0,1]$. In other words, all singular values of $\boldsymbol{X}_0=\boldsymbol{U}_{[:,:r]}\boldsymbol{S}_0\boldsymbol{V}_{[:,:r]}^{\top}$ do not exceed 1. This is the purpose of the first step $\boldsymbol{X}_0 = \boldsymbol{M}/\Vert\boldsymbol{M}\Vert_F$.
Next, substituting $\boldsymbol{U}_{[:,:r]}\boldsymbol{S}_t\boldsymbol{V}_{[:,:r]}^{\top}$ into the equation, we get
$$ \boldsymbol{X}_{t+1} = \boldsymbol{U}_{[:,:r]}\left(a\boldsymbol{S}_t + b\boldsymbol{S}_t^3 + c\boldsymbol{S}_t^5\right)\boldsymbol{V}_{[:,:r]}^{\top} $$This means that the iteration does not change $\boldsymbol{U}_{[:,:r]}$ and $\boldsymbol{V}_{[:,:r]}^{\top}$ on the left and right; essentially, it’s an iteration on the diagonal matrix:
$$ \boldsymbol{S}_{t+1} = a\boldsymbol{S}_t + b\boldsymbol{S}_t^3 + c\boldsymbol{S}_t^5 $$And the power of a diagonal matrix is equivalent to taking the power of each diagonal element, so this is essentially equivalent to the scalar iteration for $x_t$:
$$ x_{t+1} = a x_t + b x_t^3 + c x_t^5 $$Since $\boldsymbol{X}_0 = \boldsymbol{M}/\Vert\boldsymbol{M}\Vert_F$ has already compressed all singular values into $(0,1]$, we hope that starting from any $x_0\in(0,1]$, after $T$ iterations, $x_T$ can be as close as possible to 1, so the iteration can sufficiently approximate $\msign$. In this way, we simplify the matrix iteration analysis to scalar iteration analysis, greatly reducing the difficulty of analysis.
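The reduction from a matrix iteration to a scalar iteration can be checked numerically. The sketch below (NumPy, with illustrative helper names `ns_step` and `poly` and an arbitrary random test matrix) runs the matrix update and, in parallel, applies the scalar polynomial to the initial singular values while keeping $\boldsymbol{U}$ and $\boldsymbol{V}$ fixed.

```python
import numpy as np

a, b, c = 3.4445, -4.7750, 2.0315

def ns_step(X):
    """One matrix step: a X + b X (X^T X) + c X (X^T X)^2."""
    G = X.T @ X
    return a * X + b * X @ G + c * X @ G @ G

def poly(x):
    """The corresponding scalar step: a x + b x^3 + c x^5."""
    return a * x + b * x**3 + c * x**5

rng = np.random.default_rng(0)
M = rng.standard_normal((6, 4))
X = M / np.linalg.norm(M)                         # Frobenius normalization
U, s, Vt = np.linalg.svd(X, full_matrices=False)  # U and Vt stay fixed below

for t in range(5):
    X = ns_step(X)       # matrix iteration
    s = poly(s)          # scalar iteration on each singular value
    err = np.abs(X - (U * s) @ Vt).max()
    print(t + 1, np.round(s, 4), f"max deviation = {err:.2e}")
```

The deviation between the two should stay at the level of floating-point rounding, confirming that only the diagonal factor changes.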
Optimization Solution#
In fact, the solution for $a,b,c$ was briefly discussed when we first introduced Muon in “Appreciation of Muon Optimizer: An Essential Leap from Vectors to Matrices”. The basic idea is to treat $a,b,c$ as optimization parameters, construct a loss function using the difference between $x_T$ and $1$, and then optimize with SGD.
The approach in this article is largely similar, but with slight adjustments. Evidently, the optimization results will depend on the singular value distribution. Previously, the author’s idea was to use random matrix SVD to simulate the real singular value distribution. However, SVD is time-consuming and labor-intensive, and the results also depend on the matrix shape. It now seems largely unnecessary. We instead sample points uniformly in $[0,1]$ and then select the $k$ points with the largest $|x_T-1|$ to construct the loss. This transforms it into a $\min\text{-}\max$ problem, minimizing the influence of the singular value distribution as much as possible:
```python
import jax
import jax.numpy as jnp
from tqdm import tqdm

def loss(w, x, k=50):
    # the same (a, b, c) is shared by all `iters` steps
    for a, b, c in [w] * iters:
        x = a * x + b * x**3 + c * x**5
    # mean of the k largest deviations from 1 (approximate min-max objective)
    return jnp.abs(x - 1).sort()[-k:].mean()

@jax.jit
def grad(w, x, tol=0.1):
    # norm-clipped gradient, mixed with gradients at perturbed coefficients
    G = lambda w, x: (g := jax.grad(loss)(w, x)) / jnp.fmax(jnp.linalg.norm(g), 1)
    return 0.6 * G(w, x) + 0.2 * (G(w + tol / 2, x) + G(w - tol / 2, x))

iters = 5
x = jnp.linspace(0, 1, 10001)[1:]
w = jnp.array([1.5, -0.5, 0])
m, v = jnp.zeros_like(w), jnp.zeros_like(w)
lr = 1e-3
pbar = tqdm(range(20000), ncols=0, desc='Adam')

for i in pbar:
    l, g = loss(w, x), grad(w, x)
    # Adam-style update without bias correction
    m = 0.9 * m + 0.1 * g
    v = 0.999 * v + 0.001 * g**2
    w = w - lr * m / jnp.sqrt(v + 1e-20)
    pbar.set_description(f'Loss: {l:.6f}, LR: {lr:.6f}')
    if i in [10000]:
        lr *= 0.1
```
Furthermore, the optimizer has been changed from SGD to Adam, which makes it easier to control the parameter update magnitude. To enhance the solution’s robustness against noise, we add some perturbation to $a,b,c$ and then mix in the gradients after perturbation. The optimization results from the script above are:
$$ (a,b,c)=(3.3748, -4.6969, 2.1433) $$It can be seen that this is not far from KellerJordan’s solution. Let’s further compare the differences between the two using images:
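The criterion used by the script can also be evaluated directly for any candidate coefficients. The following sketch re-implements the top-$k$ loss in NumPy (the function name `minmax_loss` and the dictionary layout are assumptions for illustration) and prints it for the two solutions at $T=5$. It only evaluates the loss; it does not reproduce the optimization.

```python
import numpy as np

def minmax_loss(coef, steps=5, k=50, n=10000):
    """Mean of the k largest |x_T - 1| over a uniform grid on (0, 1]."""
    a, b, c = coef
    x = np.linspace(0, 1, n + 1)[1:]   # same grid as the script: 0 is excluded
    for _ in range(steps):
        x = a * x + b * x**3 + c * x**5
    return np.sort(np.abs(x - 1))[-k:].mean()

candidates = {
    "KellerJordan": (3.4445, -4.7750, 2.0315),
    "Ours": (3.3748, -4.6969, 2.1433),
}
losses = {name: minmax_loss(coef) for name, coef in candidates.items()}
for name, value in losses.items():
    print(f"{name:>12s}: top-k loss = {value:.4f}")
```

Changing `k` or the grid changes which singular values dominate the loss, so the numbers are only meaningful relative to the same settings.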
Initial Value Distribution#
Before further discussion, we need to clarify a question: how small of singular values do we really need to care about? This goes back to the distribution of $\boldsymbol{S}_0$. Since $\boldsymbol{S}_0$ is normalized by the Frobenius norm, $\mathop{\text{diag}}(\boldsymbol{S}_0)$ is effectively an $r$-dimensional unit vector. If all singular values are equal, then it can be deduced that each singular value is $1/\sqrt{r}$.
Therefore, by the pigeonhole principle, in the non-uniform case there must exist singular values smaller than $1/\sqrt{r}$. To be safe, we leave a margin of, say, a factor of 10, which means we should at least account for singular values of size $0.1/\sqrt{r}$. In practice, the probability of a matrix being strictly low-rank (i.e., having singular values exactly equal to 0) is very small, so we generally assume the matrix is full rank, i.e., $r = \min(n,m)$. Thus, we should at least account for singular values of size $0.1/\sqrt{\min(n,m)}$.
Considering that the largest LLMs currently have a hidden_size around $8192 \sim 100^2$, based on this numerical estimate, a general-purpose Muon optimizer’s $\msign$ algorithm should at least account for singular values of size $0.001$, meaning it should be able to map $0.001$ to a value close to 1. From this perspective, both KellerJordan’s solution and our newly derived solution fall somewhat short.
Note: For a discussion on initial value distribution, please refer to “Iterative Orthogonalization Scaling Laws”.
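The arithmetic behind this estimate is easy to reproduce. The sketch below uses only the standard library; the helper names `smallest_sv_to_cover` and `newton_schulz_scalar` and the list of hidden sizes are assumptions for illustration. It prints the threshold $0.1/\sqrt{n}$ for a few sizes and then shows where the shared-coefficient iteration with $T=5$ sends $x_0 = 0.001$.

```python
import math

def smallest_sv_to_cover(hidden_size, safety=10.0):
    """Rule of thumb from the text: 1/sqrt(r) divided by a safety factor, with r = hidden_size."""
    return 1.0 / (safety * math.sqrt(hidden_size))

def newton_schulz_scalar(x, steps=5, coef=(3.4445, -4.7750, 2.0315)):
    """Scalar iteration x <- a x + b x^3 + c x^5 with KellerJordan's coefficients by default."""
    a, b, c = coef
    for _ in range(steps):
        x = a * x + b * x**3 + c * x**5
    return x

for n in (1024, 4096, 8192):
    print(n, f"{smallest_sv_to_cover(n):.5f}")

print("x_5 for x_0 = 0.001:", newton_schulz_scalar(0.001))
```

For $n = 8192$ the threshold is about $0.0011$, and with the shared coefficients an input of $0.001$ only reaches roughly $0.47$ after five steps, which is the shortfall mentioned above.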
Unlocking Constraints#
At this point, @YouJiacheng (one of the main proponents of Muon) on Twitter proposed a very clever idea: we can use different coefficients at each iteration step! That is, change the iteration to
$$ \boldsymbol{X}_{t+1} = a_{t+1}\boldsymbol{X}_t + b_{t+1}\boldsymbol{X}_t(\boldsymbol{X}_t^{\top}\boldsymbol{X}_t) + c_{t+1}\boldsymbol{X}_t(\boldsymbol{X}_t^{\top}\boldsymbol{X}_t)^2 $$The advantage of this change is that once $T$ is chosen, the total computational load does not change at all. However, from a fitting perspective, what were originally only $3$ trainable parameters now become $3T$ parameters, greatly enhancing the fitting capability. He himself provided a reference result for a 6-step iteration:
t | a | b | c |
---|---|---|---|
1 | 3955/1024 | -8306/1024 | 5008/1024 |
2 | 3735/1024 | -6681/1024 | 3463/1024 |
3 | 3799/1024 | -6499/1024 | 3211/1024 |
4 | 4019/1024 | -6385/1024 | 2906/1024 |
5 | 2677/1024 | -3029/1024 | 1162/1024 |
6 | 2172/1024 | -1833/1024 | 682/1024 |
We can plot them for comparison:
For fairness, both KellerJordan’s solution and our solution (“Ours”) were also changed to $T=6$. It can be seen that YouJiacheng’s solution shows significant improvement in both smoothness and overall approximation quality, fully demonstrating the “full potential” unleashed by removing parameter sharing.
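A per-step schedule is straightforward to use in code: the loop simply reads a different $(a_t, b_t, c_t)$ at each step. The sketch below is a NumPy illustration under stated assumptions: the names `scalar_iter` and `msign_schedule` are hypothetical, the matrix version handles only a 2-D array with no more rows than columns, and the coefficients are copied from the table above.

```python
import numpy as np

# Per-step coefficients from the table above (numerators over 1024).
YOU_COEFS = np.array([
    [3955, -8306, 5008],
    [3735, -6681, 3463],
    [3799, -6499, 3211],
    [4019, -6385, 2906],
    [2677, -3029, 1162],
    [2172, -1833, 682],
]) / 1024
KJ_COEFS = np.array([[3.4445, -4.7750, 2.0315]] * 6)  # shared coefficients, T = 6

def scalar_iter(x, coefs):
    """Apply x <- a x + b x^3 + c x^5 with a different (a, b, c) at each step."""
    for a, b, c in coefs:
        x = a * x + b * x**3 + c * x**5
    return x

def msign_schedule(M, coefs):
    """Matrix version for a 2-D array with rows <= columns (illustrative helper)."""
    X = M / np.linalg.norm(M)          # Frobenius normalization
    for a, b, c in coefs:
        G = X @ X.T
        X = a * X + (b * G + c * G @ G) @ X
    return X

x0 = np.array([0.001, 0.01, 0.1, 0.5, 1.0])
print("x0          :", x0)
print("KellerJordan:", np.round(scalar_iter(x0, KJ_COEFS), 4))
print("YouJiacheng :", np.round(scalar_iter(x0, YOU_COEFS), 4))
```

Swapping in another table only requires changing the coefficient array; the number of matrix multiplications per step is unchanged.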
Try it Yourself#
How was YouJiacheng’s solution derived? He shared his code here. His approach also uses Adam, but it involves many different loss functions, which makes it somewhat complicated to follow. In fact, running our previous script with the initialization he provided gives equally good results:
t | a | b | c |
---|---|---|---|
1 | 4140/1024 | -7553/1024 | 3571/1024 |
2 | 3892/1024 | -6637/1024 | 2973/1024 |
3 | 3668/1024 | -6456/1024 | 3021/1024 |
4 | 3248/1024 | -6211/1024 | 3292/1024 |
5 | 2792/1024 | -5759/1024 | 3796/1024 |
6 | 3176/1024 | -5507/1024 | 4048/1024 |
Reference code:
```python
import jax
import jax.numpy as jnp
from tqdm import tqdm

def loss(w, x, k=50):
    # w has shape (iters, 3): one (a, b, c) triple per iteration step
    for a, b, c in w:
        x = a * x + b * x**3 + c * x**5
    # mean of the k largest deviations from 1 (approximate min-max objective)
    return jnp.abs(x - 1).sort()[-k:].mean()

@jax.jit
def grad(w, x, tol=0.1):
    # norm-clipped gradient, mixed with gradients at perturbed coefficients
    G = lambda w, x: (g := jax.grad(loss)(w, x)) / jnp.fmax(jnp.linalg.norm(g), 1)
    return 0.6 * G(w, x) + 0.2 * (G(w + tol / 2, x) + G(w - tol / 2, x))

iters = 6
x = jnp.linspace(0, 1, 10001)[1:]
# YouJiacheng's initialization, repeated for every step
w = jnp.array([[3.5, -6.04444444444, 2.84444444444]] * iters)
m, v = jnp.zeros_like(w), jnp.zeros_like(w)
lr = 1e-3
pbar = tqdm(range(20000), ncols=0, desc='Adam')

for i in pbar:
    l, g = loss(w, x), grad(w, x)
    # Adam-style update without bias correction
    m = 0.9 * m + 0.1 * g
    v = 0.999 * v + 0.001 * g**2
    w = w - lr * m / jnp.sqrt(v + 1e-20)
    pbar.set_description(f'Loss: {l:.6f}, LR: {lr:.6f}')
    if i in [10000]:
        lr *= 0.1
```
Comparison below (labeled “Ours-X”):
As the figures show, compared to YouJiacheng’s solution, our results exhibit more oscillation but achieve a larger slope on $[0,0.001]$.
Other Solution Sets#
If readers desire solutions with less oscillation, they only need to increase the value of $k$. For instance, the results for $k=200$ are:
t | a | b | c |
---|---|---|---|
1 | 4059/1024 | -7178/1024 | 3279/1024 |
2 | 3809/1024 | -6501/1024 | 2925/1024 |
3 | 3488/1024 | -6308/1024 | 3063/1024 |
4 | 2924/1024 | -5982/1024 | 3514/1024 |
5 | 2439/1024 | -5439/1024 | 4261/1024 |
6 | 3148/1024 | -5464/1024 | 4095/1024 |
At this point, it’s almost identical to YouJiacheng’s solution (Ours-X2):
Additionally, here’s a 5-step solution for easy comparison with the original solution:
t | a | b | c |
---|---|---|---|
1 | 4.6182 | -12.9582 | 9.3299 |
2 | 3.8496 | -7.9585 | 4.3052 |
3 | 3.5204 | -7.2918 | 4.0606 |
4 | 3.2067 | -6.8243 | 4.2802 |
5 | 3.2978 | -5.7848 | 3.8917 |
Effect plot (Ours-X3):
Improved Initial Value#
This concludes our discussion of solving for $a,b,c$. In summary, using different $a,b,c$ at each step significantly improves the convergence of the Newton-Schulz iteration at no additional computational cost, truly a “free lunch”.
Besides optimizing the coefficients of the Newton-Schulz iteration, are there other ways to improve its convergence properties? Indeed there are. @johanwind, @YouJiacheng, @ZhangRuichong, and others have discovered that we can leverage the characteristics of Newton-Schulz iteration to almost freely improve the quality of the initial value, thereby increasing the convergence speed. @leloykun provided a reference implementation here.
Specifically, the main efforts to improve Newton-Schulz iteration can be summarized as “maximizing the convergence speed of singular values close to zero while ensuring convergence”. If we can pre-amplify these near-zero singular values slightly, we can also increase the convergence speed without changing the iterative algorithm. Currently, to compress singular values into $[0,1]$, we use Frobenius norm normalization $\boldsymbol{M}/\Vert\boldsymbol{M}\Vert_F$, which compresses singular values into
$$ \sigma_i \quad\to\quad \frac{\sigma_i}{\Vert\boldsymbol{M}\Vert_F} = \frac{\sigma_i}{\sqrt{\sum\limits_{j=1}^r \sigma_j^2}} \in [0, 1] $$While this approach achieves the goal, it also has the issue of over-compression. The most compact compression method should be $\sigma_i\to \sigma_i/\sigma_1$, i.e., spectral normalization. The problem is that the spectral norm is not as easy to compute as the Frobenius norm, so we reluctantly chose the Frobenius norm. However, we have
$$ \sigma_1 \quad\leq\quad \underbrace{\sqrt[8]{\sum_{j=1}^r \sigma_j^8}}_{\sqrt[4]{\Vert(\boldsymbol{M}^{\top}\boldsymbol{M})^2\Vert_F}}\quad\leq\quad \underbrace{\sqrt[4]{\sum_{j=1}^r \sigma_j^4}}_{\sqrt{\Vert\boldsymbol{M}^{\top}\boldsymbol{M}\Vert_F}} \quad\leq\quad \underbrace{\sqrt{\sum_{j=1}^r \sigma_j^2}}_{\Vert\boldsymbol{M}\Vert_F} $$This means that using $\sqrt[4]{\Vert(\boldsymbol{M}^{\top}\boldsymbol{M})^2\Vert_F}$ or $\sqrt{\Vert\boldsymbol{M}^{\top}\boldsymbol{M}\Vert_F}$ as normalization factors is theoretically better than $\Vert\boldsymbol{M}\Vert_F$. Very cleverly, under Newton-Schulz iteration, their computation is almost free! To understand this, let’s write out the first iteration step:
$$ \boldsymbol{X}_0 = \frac{\boldsymbol{M}}{\Vert\boldsymbol{M}\Vert_F},\qquad \boldsymbol{X}_1 = a\boldsymbol{X}_0 + b\boldsymbol{X}_0(\boldsymbol{X}_0^{\top}\boldsymbol{X}_0) + c\boldsymbol{X}_0(\boldsymbol{X}_0^{\top}\boldsymbol{X}_0)^2 $$It can be seen that $\boldsymbol{X}_0^{\top}\boldsymbol{X}_0$ and $(\boldsymbol{X}_0^{\top}\boldsymbol{X}_0)^2$ must be computed anyway. Therefore, we can reuse them to compute the Frobenius norms above and then re-normalize. Reference code:
```python
def msign(x, steps=5, eps=1e-20):
    a, b, c, y = 3.4445, -4.7750, 2.0315, x.astype('bfloat16')
    y = y.mT if x.shape[-2] > x.shape[-1] else y
    y /= ((y**2).sum(axis=(-2, -1), keepdims=True) + eps)**0.5
    for i in range(steps):
        y4 = (y2 := y @ y.mT) @ y2
        if i == 0:
            # re-normalize by ||(Y Y^T)^2||_F^(1/4), reusing the products already computed
            n = ((y4**2).sum(axis=(-2, -1), keepdims=True) + eps)**0.125
            y, y2, y4 = y / n, y2 / n**2, y4 / n**4
        y = a * y + (b * y2 + c * y4) @ y
    return y.mT if x.shape[-2] > x.shape[-1] else y
```
In practical tests, for a $100\times 100$ random Gaussian matrix, most of the smallest singular values after improvement were more than 2 times larger than before improvement, and the average singular value was also closer to 1. However, the Muon authors have also stated that it might introduce additional instability, so it has not yet been adopted into the official code.
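The chain of inequalities and the size of the gain can be checked numerically. The following is a minimal NumPy sketch, assuming a $100\times 100$ Gaussian test matrix with a fixed seed; the variable names are illustrative and the exact numbers depend on the random draw.

```python
import numpy as np

rng = np.random.default_rng(0)
M = rng.standard_normal((100, 100))
s = np.linalg.svd(M, compute_uv=False)   # singular values, in descending order

G = M.T @ M
sigma1 = s[0]                            # spectral norm
n8 = np.linalg.norm(G @ G) ** 0.25       # (sum of sigma^8)^(1/8)
n4 = np.linalg.norm(G) ** 0.5            # (sum of sigma^4)^(1/4)
nF = np.linalg.norm(M)                   # (sum of sigma^2)^(1/2)

print("sigma_1, n8, n4, nF:", sigma1, n8, n4, nF)

# Dividing by n8 instead of nF scales every singular value by the same factor.
gain = nF / n8
print("gain from re-normalization:", gain)
print("smallest normalized singular value:", s[-1] / nF, "->", s[-1] / n8)
```

Because the re-normalization is a uniform rescaling, the gain applies equally to all singular values, including the smallest ones that limit convergence speed.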
Summary#
This article introduced an optimization approach for computing $\msign$ using Newton-Schulz iteration. The results obtained show significant improvements in iteration convergence speed and effectiveness compared to Muon’s official solution.
Finally, it should be noted that for Muon, small-scale experimental results suggest that the computation precision of $\msign$ does not seem to have a necessary correlation with the model’s final performance. For small models, improving the precision of $\msign$ appears to only slightly accelerate convergence in the early stages, but the final results remain unchanged. It is currently unclear whether this conclusion holds for larger-scale models.
@online{kexuefm-10922,
title={Newton-Schulz Iteration for the msign Operator (Part 1)},
author={苏剑林},
year={2025},
month={05},
url={\url{https://kexue.fm/archives/10922}},
}