
Asymptotic Estimation of AdamW's Weight RMS


This is a gemini-2.5-flash translation of a Chinese article.

It has NOT been vetted for errors. You should have the original article open in a parallel tab at all times.

In 《Why Adam’s Update RMS is 0.2?》, we used mean-field approximation to estimate Adam’s Update RMS. Shortly after, reader @EIFY pointed out that the same result had already appeared in the paper 《Rotational Equilibrium: How Weight Decay Balances Learning Across Neural Networks》. After reading it, I found that it contained not only the estimation of Update RMS but also the estimation of Weight RMS.

In other words, for models trained with AdamW, the RMS of their weights can be estimated asymptotically in advance. Doesn’t this conclusion seem a bit surprising? It certainly surprised me the first time I saw it. Intuitively, the weight norm is something the model learns from the training set, yet this result says it is already hidden in the optimizer’s hyperparameters, which is rather counter-intuitive.

In this article, we will still use the mean-field approximation method to reproduce the asymptotic estimation of Weight RMS.

Sliding Perspective

First, let’s review AdamW’s update rule:

$$ \begin{equation} \text{Adam}\color{skyblue}{\text{W}}:=\left\{\begin{aligned} &\boldsymbol{m}_t = \beta_1 \boldsymbol{m}_{t-1} + \left(1 - \beta_1\right) \boldsymbol{g}_t\\ &\boldsymbol{v}_t = \beta_2 \boldsymbol{v}_{t-1} + \left(1 - \beta_2\right) \boldsymbol{g}_t^2\\ &\hat{\boldsymbol{m}}_t = \boldsymbol{m}_t\left/\left(1 - \beta_1^t\right)\right.\\ &\hat{\boldsymbol{v}}_t = \boldsymbol{v}_t\left/\left(1 - \beta_2^t\right)\right.\\ &\boldsymbol{u}_t =\hat{\boldsymbol{m}}_t\left/\left(\sqrt{\hat{\boldsymbol{v}}_t} + \epsilon\right)\right.\\ &\boldsymbol{\theta}_t = \boldsymbol{\theta}_{t-1} - \eta_t (\boldsymbol{u}_t \color{skyblue}{ + \lambda_t \boldsymbol{\theta}_{t-1}}) \end{aligned}\right. \tag{1} \end{equation} $$

To reiterate, bold symbols here are by default vectors in $\mathbb{R}^d$. Vector multiplication and division (including squaring and taking square roots) are by default element-wise Hadamard products/quotients.

Similar to 《Why Adam’s Update RMS is 0.2?》, we consider $t\to\infty$ (for $\beta_1,\beta_2$) and $\epsilon\to 0$, so $\boldsymbol{u}_t=\boldsymbol{m}_t/\sqrt{\boldsymbol{v}_t}$. For now, let’s consider the case where $\eta_t, \lambda_t$ are constants, so their subscripts can be omitted. Let $\beta_3 = 1-\eta\lambda$, then we have:

$$ \begin{equation}\boldsymbol{\theta}_t = \beta_3\boldsymbol{\theta}_{t-1} + (1-\beta_3)(-\boldsymbol{u}_t/\lambda) \tag{2} \end{equation} $$

This equation shows that we can understand Weight Decay from the perspective of the Exponential Moving Average (EMA) of update quantities. This is a very meaningful shift in perspective and forms the basis for works such as 《How to set AdamW’s weight decay as you scale model and dataset size》 and 《Power Lines: Scaling Laws for Weight Decay and Batch Size in LLM Pre-training》.
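
To make this shift in perspective concrete, here is a minimal numerical sketch (my own illustration, with arbitrary values of $\eta,\lambda$) checking that the AdamW step and the EMA form in Eq. (2) coincide when $\beta_3 = 1-\eta\lambda$:

import numpy as np

# Check that theta_{t-1} - eta*(u + lam*theta_{t-1}) equals
# beta3*theta_{t-1} + (1 - beta3)*(-u/lam) with beta3 = 1 - eta*lam.
eta, lam = 1e-3, 0.1
beta3 = 1 - eta * lam
theta, u = np.random.randn(5), np.random.randn(5)
adamw_step = theta - eta * (u + lam * theta)
ema_step = beta3 * theta + (1 - beta3) * (-u / lam)
print(np.allclose(adamw_step, ema_step))  # True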

Weighted Average

According to Eq. (2), we can expand $\boldsymbol{\theta}_t$ into a weighted average form:

$$ \begin{equation}\boldsymbol{\theta}_t = \beta_3^t\boldsymbol{\theta}_0 + (1-\beta_3)\sum_{i=1}^t \beta_3^{t-i} (-\boldsymbol{u}_i/\lambda) \tag{3} \end{equation} $$

Similarly, $\boldsymbol{m}_t$ and $\boldsymbol{v}_t$ can also be expanded as:

$$ \begin{equation}\boldsymbol{m}_t = (1 - \beta_1)\sum_{i=1}^t \beta_1^{t-i}\boldsymbol{g}_i,\qquad \boldsymbol{v}_t = (1 - \beta_2)\sum_{i=1}^t \beta_2^{t-i}\boldsymbol{g}_i^2 \tag{4} \end{equation} $$

Here’s a small detail: in the expression for $\boldsymbol{\theta}_t$, we retained $\boldsymbol{\theta}_0$, but in the expressions for $\boldsymbol{m}_t$ and $\boldsymbol{v}_t$, we did not retain $\boldsymbol{m}_0$ and $\boldsymbol{v}_0$. There are two reasons for this: 1. $\boldsymbol{m}$ and $\boldsymbol{v}$ are generally initialized to zero; 2. Even if they were not initialized to zero, the corresponding $\beta_1^t$ and $\beta_2^t$ would become sufficiently close to zero, so the effect of initialization can be ignored.

However, $\boldsymbol{\theta}$ represents model weights, which are usually not initialized to zero, and $\beta_3$ is often very close to 1. For the entire training cycle, $\beta_3^t$ may not approach zero sufficiently. Therefore, we explicitly retain $\beta_3^t$ and $\boldsymbol{\theta}_0$, making trade-offs as needed.

Quick Estimation

Our task is to estimate the Weight RMS, i.e., $\Vert\boldsymbol{\theta}_t\Vert_{RMS}$. As the name suggests, it is the Root Mean Square of each component:

$$ \begin{equation}\Vert\boldsymbol{\theta}\Vert_{RMS} = \sqrt{\frac{1}{d}\sum_{i=1}^d \theta_i^2},\qquad\qquad \text{where }\boldsymbol{\theta} = (\theta_1,\theta_2,\cdots,\theta_d) \tag{5} \end{equation} $$

The difference between it and the norm is simply the division by $\sqrt{d}$, so most properties of the norm also hold for RMS. For $\Vert\boldsymbol{\theta}_t\Vert_{RMS}$, we have a quick but not-so-accurate derivation method: directly taking $\Vert\cdot\Vert_{RMS}^2$ on both sides of Eq. (2), we can obtain:

$$ \begin{equation}\begin{aligned} \Vert\boldsymbol{\theta}_t\Vert_{RMS}^2 =&\,\Vert\beta_3\boldsymbol{\theta}_{t-1} + (1-\beta_3)(-\boldsymbol{u}_t/\lambda)\Vert_{RMS}^2 \\[5pt] =&\, \beta_3^2\Vert\boldsymbol{\theta}_{t-1}\Vert_{RMS}^2 + (1-\beta_3)^2\Vert\boldsymbol{u}_t\Vert_{RMS}^2/\lambda^2 - 2\beta_3(1-\beta_3)\boldsymbol{\theta}_{t-1}\cdot\boldsymbol{u}_t/(\lambda d) \end{aligned} \tag{6} \end{equation} $$

Assuming $\boldsymbol{\theta}_{t-1}$ and $\boldsymbol{u}_t$ are nearly orthogonal, then $\boldsymbol{\theta}_{t-1}\cdot\boldsymbol{u}_t\approx 0$. This is generally a good approximation in high-dimensional spaces (refer to 《Angle Distribution of Two Random Vectors in n-Dimensional Space》). We have already calculated $\Vert\boldsymbol{u}_t\Vert_{RMS}$, and the answer is approximately $\sqrt{\frac{1-\beta_1}{1+\beta_1}}$. Finally, we consider the steady-state result, so $\Vert\boldsymbol{\theta}_t\Vert_{RMS}^2=\Vert\boldsymbol{\theta}_{t-1}\Vert_{RMS}^2$. Thus, we have:

$$ \begin{equation}(1-\beta_3^2)\Vert\boldsymbol{\theta}_t\Vert_{RMS}^2 \approx (1-\beta_3)^2 \frac{1-\beta_1}{1+\beta_1} /\lambda^2\qquad\Rightarrow\qquad \Vert\boldsymbol{\theta}_t\Vert_{RMS} \approx \sqrt{\frac{1-\beta_1}{1+\beta_1}\frac{\eta}{2\lambda}} \tag{7} \end{equation} $$

In going from the left-hand side to the right-hand side, we also used the approximation $\beta_3\approx 1$. The final result has some error because $\boldsymbol{\theta}_{t-1}\cdot\boldsymbol{u}_t\approx 0$ is not entirely accurate, but the conclusion that $\Vert\boldsymbol{\theta}_t\Vert_{RMS}\propto \sqrt{\eta/\lambda}$ is correct. Similar derivations also appear in 《Why Gradients Rapidly Increase Near the End of Training》.
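
As a quick numerical illustration of Eq. (7) (my own sketch; the hyperparameter values are merely typical, not taken from the original article):

# Evaluate Eq. (7) for beta1 = 0.9, eta = 1e-3, lam = 0.1.
beta1, eta, lam = 0.9, 1e-3, 0.1
update_rms = ((1 - beta1) / (1 + beta1)) ** 0.5        # ≈ 0.229, Adam's Update RMS
weight_rms = (update_rms**2 * eta / (2 * lam)) ** 0.5  # Eq. (7), ≈ 0.016
print(update_rms, weight_rms)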

Better Approximation

In many cases, it is sufficient to know that $\Vert\boldsymbol{\theta}_t\Vert_{RMS}\propto \sqrt{\eta/\lambda}$, which is a relatively general conclusion. For readers seeking a more accurate conclusion, we can use the mean-field method to obtain a better approximation. The cost is a significantly more complex calculation process, but the benefit is that we can gain more and clearer insights.

Step One

Starting from Eq. (3), the summation term itself has the form of a weighted average, so we apply the first mean-field approximation:

$$ \begin{equation}\underbrace{\frac{1-\beta_3}{1-\beta_3^t}\sum_{i=1}^t \beta_3^{t-i} \boldsymbol{u}_i}_{\text{denoted as }\bar{\boldsymbol{u}}_t} = \frac{1-\beta_3}{1-\beta_3^t}\sum_{i=1}^t \beta_3^{t-i} \frac{\hat{\boldsymbol{m}}_i}{\sqrt{\hat{\boldsymbol{v}}_i}}\approx \frac{\bar{\boldsymbol{m}}_t \,\,\triangleq\,\, \frac{1-\beta_3}{1-\beta_3^t}\sum_{i=1}^t \beta_3^{t-i}\hat{\boldsymbol{m}}_i}{\sqrt{\bar{\boldsymbol{v}}_t \,\,\triangleq\,\, \frac{1-\beta_3}{1-\beta_3^t}\sum_{i=1}^t \beta_3^{t-i}\hat{\boldsymbol{v}}_i}} \tag{8} \end{equation} $$

Now, returning to Eq. (3), since $\boldsymbol{\theta}_0$ is a randomly initialized vector, we can assume $\boldsymbol{\theta}_0$ is orthogonal to $\bar{\boldsymbol{u}}_t$. Thus, we have:

$$ \begin{equation}\Vert\boldsymbol{\theta}_t\Vert_{RMS}^2 \approx \beta_3^{2t}\Vert\boldsymbol{\theta}_0\Vert_{RMS}^2 + (1-\beta_3^t)^2 \lambda^{-2}\Vert \bar{\boldsymbol{u}}_t\Vert_{RMS}^2 \tag{9} \end{equation} $$

Now we need to find $\Vert \bar{\boldsymbol{u}}_t\Vert_{RMS}^2$. Based on previous experience, we assume that $\boldsymbol{g}_j$ are independently and identically distributed following $\mathcal{N}(\boldsymbol{\mu},\boldsymbol{\sigma}^2)$, and then we calculate:

$$ \begin{equation}\mathbb{E}[\bar{\boldsymbol{u}}_t^2] \approx \mathbb{E}\left[\frac{\bar{\boldsymbol{m}}_t^2}{\bar{\boldsymbol{v}}_t}\right] \approx \frac{\mathbb{E}[\bar{\boldsymbol{m}}_t^2]}{\mathbb{E}[\bar{\boldsymbol{v}}_t]} \tag{10} \end{equation} $$

Finally, by averaging the components of $\mathbb{E}[\bar{\boldsymbol{u}}_t^2]$, we can use it as an approximation for $\Vert \bar{\boldsymbol{u}}_t\Vert_{RMS}^2$.

Step Two

Combining with Eq. (4), we get:

$$ \begin{gather} \sum_{i=1}^t \beta_3^{t-i}\hat{\boldsymbol{m}}_i = (1 - \beta_1)\sum_{i=1}^t \beta_3^{t-i} \sum_{j=1}^i \beta_1^{i-j}\boldsymbol{g}_j = (1 - \beta_1)\sum_{j=1}^t \frac{\beta_3^{t-j+1} - \beta_1^{t-j+1}}{\beta_3 - \beta_1}\boldsymbol{g}_j \tag{11} \\ \sum_{i=1}^t \beta_3^{t-i}\hat{\boldsymbol{v}}_i = (1 - \beta_2)\sum_{i=1}^t \beta_3^{t-i} \sum_{j=1}^i \beta_2^{i-j}\boldsymbol{g}_j^2 = (1 - \beta_2)\sum_{j=1}^t \frac{\beta_3^{t-j+1} - \beta_2^{t-j+1}}{\beta_3 - \beta_2}\boldsymbol{g}_j^2 \tag{12} \end{gather} $$

If you have no idea how to carry out the last double-summation simplification, it can be handed to Kimi (refer to the link). From the above equations, it is clear that $\bar{\boldsymbol{m}}_t$ and $\bar{\boldsymbol{v}}_t$ are weighted averages of the gradients and squared gradients, respectively. Therefore, calculating $\Vert \bar{\boldsymbol{u}}_t\Vert_{RMS}^2$ is essentially the same as calculating $\Vert \boldsymbol{u}_t\Vert_{RMS}^2$ in 《Why Adam’s Update RMS is 0.2?》, except for the weighting coefficients.
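
The swap-of-summation identity behind Eqs. (11) and (12) is also easy to check numerically; here is a small sketch with arbitrary illustrative values of $t,\beta_1,\beta_3$:

import numpy as np

# Check: sum_{i=1}^t beta3^{t-i} sum_{j=1}^i beta1^{i-j} g_j
#      = sum_{j=1}^t (beta3^{t-j+1} - beta1^{t-j+1}) / (beta3 - beta1) * g_j
t, beta1, beta3 = 50, 0.9, 0.999
g = np.random.randn(t)
lhs = sum(beta3**(t - i) * sum(beta1**(i - j) * g[j - 1] for j in range(1, i + 1))
          for i in range(1, t + 1))
rhs = sum((beta3**(t - j + 1) - beta1**(t - j + 1)) / (beta3 - beta1) * g[j - 1]
          for j in range(1, t + 1))
print(np.allclose(lhs, rhs))  # True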

Step Three

Let’s first calculate the denominator:

$$ \begin{equation}\begin{aligned} \mathbb{E}[\bar{\boldsymbol{v}}_t] =&\,\frac{(1 - \beta_3)(1 - \beta_2)}{1 - \beta_3^t}\sum_{j=1}^t \frac{\beta_3^{t-j+1} - \beta_2^{t-j+1}}{\beta_3 - \beta_2}\mathbb{E}[\boldsymbol{g}_j^2] \\ =&\,\frac{(1 - \beta_3)(1 - \beta_2)}{1 - \beta_3^t}\sum_{j=1}^t \frac{\beta_3^{t-j+1} - \beta_2^{t-j+1}}{\beta_3 - \beta_2}(\boldsymbol{\mu}^2 + \boldsymbol{\sigma}^2) \\ =&\,\frac{(1 - \beta_3)(1 - \beta_2)}{(1 - \beta_3^t)(\beta_3 - \beta_2)}\left(\frac{\beta_3 - \beta_3^{t+1}}{1 - \beta_3} - \frac{\beta_2 - \beta_2^{t+1}}{1 - \beta_2}\right)(\boldsymbol{\mu}^2 + \boldsymbol{\sigma}^2) \\[5pt] \approx &\, \boldsymbol{\mu}^2 + \boldsymbol{\sigma}^2 \end{aligned} \tag{13} \end{equation} $$

The approximation in the last step holds because, in practical training, $\beta_3$ is sufficiently close to 1 and $\beta_2^{t+1}$ is sufficiently close to 0, but $\beta_3^{t+1}$ may not be. Therefore, we set $\beta_2^{t+1}$ to zero, then after simplification replace the isolated factors of $\beta_3$ with 1, and finally use the approximation $\beta_3^{t+1}\approx \beta_3^t$.
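
To see how good this approximation is, one can evaluate the exact coefficient in front of $\boldsymbol{\mu}^2 + \boldsymbol{\sigma}^2$ in Eq. (13) for representative values (my own choice of $\beta_2=0.95$, $\beta_3=0.9999$, $t=10^5$, not from the original article):

beta2, beta3, t = 0.95, 0.9999, 100000

# Exact coefficient of (mu^2 + sigma^2) in the third line of Eq. (13)
coef = (1 - beta3) * (1 - beta2) / ((1 - beta3**t) * (beta3 - beta2)) * (
    (beta3 - beta3**(t + 1)) / (1 - beta3) - (beta2 - beta2**(t + 1)) / (1 - beta2)
)
print(coef)  # very close to 1, so E[v_bar] ≈ mu^2 + sigma^2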

Step Four

Next is $\mathbb{E}[\bar{\boldsymbol{m}}_t^2] = \mathbb{E}[\bar{\boldsymbol{m}}_t]^2 + \mathbb{V}ar[\bar{\boldsymbol{m}}_t]$. The calculation of $\mathbb{E}[\bar{\boldsymbol{m}}_t]$ is similar to that of $\mathbb{E}[\bar{\boldsymbol{v}}_t]$, yielding $\boldsymbol{\mu}$. For the calculation of $\mathbb{V}ar[\bar{\boldsymbol{m}}_t]$, we use the fact that for independent terms the variances add, with the weights entering squared:

$$ \begin{equation}\begin{aligned} \mathbb{V}ar[\bar{\boldsymbol{m}}_t] =&\,\frac{(1 - \beta_3)^2(1 - \beta_1)^2}{(1-\beta_3^t)^2}\sum_{j=1}^t \left(\frac{\beta_3^{t-j+1} - \beta_1^{t-j+1}}{\beta_3 - \beta_1}\right)^2\mathbb{V}ar[\boldsymbol{g}_j] \\ =&\,\frac{(1 - \beta_3)^2(1 - \beta_1)^2}{(1-\beta_3^t)^2}\sum_{j=1}^t \left(\frac{\beta_3^{t-j+1} - \beta_1^{t-j+1}}{\beta_3 - \beta_1}\right)^2 \boldsymbol{\sigma}^2 \\ =&\,\frac{(1 - \beta_3)^2(1 - \beta_1)^2}{(1-\beta_3^t)^2(\beta_3 - \beta_1)^2}\left(\frac{\beta_3^2 - \beta_3^{2(t+1)}}{1 - \beta_3^2} + \frac{\beta_1^2 - \beta_1^{2(t+1)}}{1 - \beta_1^2} - 2\frac{\beta_1\beta_3 - \beta_1^{t+1}\beta_3^{t+1}}{1 - \beta_1\beta_3}\right) \boldsymbol{\sigma}^2 \\[5pt] \approx &\, \frac{(1 - \beta_3)(1 + \beta_3^t)}{2(1 - \beta_3^t)}\boldsymbol{\sigma}^2 \end{aligned} \tag{14} \end{equation} $$

The reason for the approximation is the same as above.

Step Five

Substituting the calculation results from the previous two sections, we have:

$$ \begin{equation}\mathbb{E}[\bar{\boldsymbol{u}}_t^2] \approx \frac{\boldsymbol{\mu}^2 + \frac{(1 - \beta_3)(1 + \beta_3^t)}{2(1 - \beta_3^t)}\boldsymbol{\sigma}^2}{\boldsymbol{\mu}^2 + \boldsymbol{\sigma}^2} \tag{15} \end{equation} $$

Then:

$$ \begin{equation}\Vert\bar{\boldsymbol{u}}_t\Vert_{RMS}^2 \approx \frac{\Vert\boldsymbol{\mu}\Vert^2/\Vert\boldsymbol{\sigma}\Vert^2 + \frac{(1 - \beta_3)(1 + \beta_3^t)}{2(1 - \beta_3^t)}}{\Vert\boldsymbol{\mu}\Vert^2/\Vert\boldsymbol{\sigma}\Vert^2 + 1} \tag{16} \end{equation} $$

Finally, we have:

$$ \begin{equation}\Vert\boldsymbol{\theta}_t\Vert_{RMS}^2 \approx \beta_3^{2t}\Vert\boldsymbol{\theta}_0\Vert_{RMS}^2 + (1-\beta_3^t)^2 \frac{\Vert\boldsymbol{\mu}\Vert^2/\Vert\boldsymbol{\sigma}\Vert^2 + \frac{(1 - \beta_3)(1 + \beta_3^t)}{2(1 - \beta_3^t)}}{\lambda^2\left(\Vert\boldsymbol{\mu}\Vert^2/\Vert\boldsymbol{\sigma}\Vert^2 + 1\right)} \tag{17} \end{equation} $$
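
For later convenience, here is a small helper (my own sketch; the function name and argument names are not from the original article) that evaluates the right-hand side of Eq. (17), where snr stands for $\Vert\boldsymbol{\mu}\Vert^2/\Vert\boldsymbol{\sigma}\Vert^2$:

def predicted_weight_rms(eta, lam, t, theta0_rms, snr=0.0):
    """Evaluate Eq. (17); snr = ||mu||^2 / ||sigma||^2."""
    beta3 = 1 - eta * lam
    u_bar_sq = (snr + (1 - beta3) * (1 + beta3**t) / (2 * (1 - beta3**t))) / (snr + 1)
    theta_sq = beta3**(2 * t) * theta0_rms**2 + (1 - beta3**t)**2 * u_bar_sq / lam**2
    return theta_sq**0.5

# e.g. the setting used in the simulation script further below
print(predicted_weight_rms(eta=0.001, lam=0.1, t=100000, theta0_rms=0.1))  # ≈ 0.0707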

Brief Analysis of Results

Eq. (17) looks quite complex; let’s observe a few special cases. First, consider the case where $\boldsymbol{\mu}=\boldsymbol{0}$. In this situation:

$$ \begin{equation}\Vert\boldsymbol{\theta}_t\Vert_{RMS}^2 \approx \beta_3^{2t}\Vert\boldsymbol{\theta}_0\Vert_{RMS}^2 + (1-\beta_3^{2t}) \frac{1 - \beta_3}{2\lambda^2} = \beta_3^{2t}\Vert\boldsymbol{\theta}_0\Vert_{RMS}^2 + (1-\beta_3^{2t}) \frac{\eta}{2\lambda} \tag{18} \end{equation} $$

In particular, if we consider $t\to\infty$, or if $\Vert\boldsymbol{\theta}_0\Vert_{RMS}^2$ is initialized to $\eta/2\lambda$, then we have:

$$ \begin{equation}\Vert\boldsymbol{\theta}_t\Vert_{RMS} \approx \sqrt{\frac{\eta}{2\lambda}} \tag{19} \end{equation} $$

This is the result given in the paper 《Rotational Equilibrium: How Weight Decay Balances Learning Across Neural Networks》, consistent with the original paper’s assumptions; it is the steady-state result of a random walk under zero mean. If we don’t consider $t\to\infty$ but instead consider the limit $\lambda\to 0$, then from Eq. (18) we will obtain:

$$ \begin{equation}\Vert\boldsymbol{\theta}_t\Vert_{RMS}^2 \approx \Vert\boldsymbol{\theta}_0\Vert_{RMS}^2 + \eta^2 t \tag{20} \end{equation} $$

This indicates that without Weight Decay, $\Vert\boldsymbol{\theta}_t\Vert_{RMS}$ grows roughly at a rate of $\eta\sqrt{t}$. On the other hand, if the Batch Size is sufficiently large, causing the signal-to-noise ratio term $\Vert\boldsymbol{\mu}\Vert^2/\Vert\boldsymbol{\sigma}\Vert^2$ to dominate, then from Eq. (17) we get:

$$ \begin{equation}\Vert\boldsymbol{\theta}_t\Vert_{RMS}^2 \approx \beta_3^{2t}\Vert\boldsymbol{\theta}_0\Vert_{RMS}^2 + (1-\beta_3^t)^2 \frac{\Vert\boldsymbol{\mu}\Vert^2/\Vert\boldsymbol{\sigma}\Vert^2}{\lambda^2(\Vert\boldsymbol{\mu}\Vert^2/\Vert\boldsymbol{\sigma}\Vert^2 + 1)} \tag{21} \end{equation} $$

This might apply to special cases where the model actively needs to increase the Weight RMS. However, empirically, the probability of this situation occurring is generally small.
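
As a side check on the no-weight-decay case of Eq. (20), one can rerun the Adam recursion with the decay switched off (my own sketch; all values illustrative) and compare the measured RMS with $\sqrt{\Vert\boldsymbol{\theta}_0\Vert_{RMS}^2+\eta^2 t}$:

import numpy as np

# Adam without weight decay: Eq. (20) predicts RMS^2 ≈ RMS_0^2 + eta^2 * t.
N, T, beta1, beta2, eta = 10000, 20000, 0.9, 0.95, 0.001
m, v = 0.0, 0.0
w = np.random.randn(N) * 0.1
for _ in range(T):
    g = np.random.randn(N)
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g**2
    w = w - eta * m / v**0.5
print((w**2).mean()**0.5, (0.1**2 + eta**2 * T)**0.5)  # the two should be comparable (Eq. (20) is a mean-field estimate)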

Simulation Experiment

We can use the following simulation script to do a quick check of the accuracy of the above results:

import numpy as np

# Simulate AdamW with zero-mean unit-variance gradients and check the
# steady-state Weight RMS against Eq. (17) / Eq. (19).
N, T = 10000, 100000          # number of weight components, number of steps
beta1, beta2 = 0.9, 0.95      # Adam's momentum coefficients
eta, lam = 0.001, 0.1         # learning rate and weight decay
m, v = 0, 0                   # first and second moment accumulators
w = np.random.randn(N) * 0.1  # weights initialized with RMS ≈ 0.1

for i in range(T):
    g = np.random.randn(N)    # gradients drawn i.i.d. from N(0, 1)
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g**2
    w = w - eta * (m / v**0.5 + lam * w)  # AdamW step (epsilon and bias correction omitted)

weight_rms = (w**2).mean()**0.5
print(weight_rms)  # should come out near sqrt(eta / (2 * lam)) ≈ 0.0707

You can change the weight initialization or the mean and variance of the gradients, etc., to see how well the final result matches Eq. (17). I’ve tried it myself, and overall, it’s quite reliable.
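
As one concrete variation (my own sketch, not from the original article), give the gradients a nonzero mean, say $\boldsymbol{g}_t\sim\mathcal{N}(0.2, 1)$ component-wise, so that $\Vert\boldsymbol{\mu}\Vert^2/\Vert\boldsymbol{\sigma}\Vert^2 = 0.04$; Eq. (17) then predicts a Weight RMS of about 1.96 instead of 0.07:

import numpy as np

# Same script, but with gradient mean 0.2 per component (snr = 0.04).
N, T, beta1, beta2, eta, lam = 10000, 100000, 0.9, 0.95, 0.001, 0.1
beta3, snr = 1 - eta * lam, 0.2**2
m, v = 0.0, 0.0
w = np.random.randn(N) * 0.1
for _ in range(T):
    g = 0.2 + np.random.randn(N)
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g**2
    w = w - eta * (m / v**0.5 + lam * w)
u_bar_sq = (snr + (1 - beta3) * (1 + beta3**T) / (2 * (1 - beta3**T))) / (snr + 1)
pred = (beta3**(2 * T) * 0.1**2 + (1 - beta3**T)**2 * u_bar_sq / lam**2)**0.5
print((w**2).mean()**0.5, pred)  # the two should be close (≈ 1.96 according to Eq. (17))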

Sign Version

With some simple adjustments, the preceding derivation can also be applied to the “SignSGDM + Weight Decay” combination:

$$ \begin{equation}\text{SignSGDM}\color{skyblue}{\text{W}}:=\left\{\begin{aligned} &\boldsymbol{m}_t = \beta_1 \boldsymbol{m}_{t-1} + \left(1 - \beta_1\right) \boldsymbol{g}_t\\ &\boldsymbol{u}_t = \operatorname{sign}(\boldsymbol{m}_t)\\ &\boldsymbol{\theta}_t = \boldsymbol{\theta}_{t-1} - \eta_t (\boldsymbol{u}_t \color{skyblue}{ + \lambda_t \boldsymbol{\theta}_{t-1}}) \end{aligned}\right. \tag{22} \end{equation} $$

The modification is that since $\operatorname{sign}(\boldsymbol{m}_t)=\boldsymbol{m}_t/\sqrt{\boldsymbol{m}_t^2}$, the definition of $\bar{\boldsymbol{v}}_t$ needs to be changed to:

$$ \begin{equation}\bar{\boldsymbol{v}}_t \triangleq \frac{1-\beta_3}{1-\beta_3^t}\sum_{i=1}^t \beta_3^{t-i}\boldsymbol{m}_i^2 \tag{23} \end{equation} $$

Then:

$$ \begin{equation}\mathbb{E}[\bar{\boldsymbol{v}}_t] = \frac{1-\beta_3}{1-\beta_3^t}\sum_{i=1}^t \beta_3^{t-i}\mathbb{E}[\boldsymbol{m}_i^2] \approx \frac{1-\beta_3}{1-\beta_3^t}\sum_{i=1}^t \beta_3^{t-i}\left(\boldsymbol{\mu}^2 + \frac{1-\beta_1}{1 + \beta_1}\boldsymbol{\sigma}^2\right) = \boldsymbol{\mu}^2 + \frac{1-\beta_1}{1 + \beta_1}\boldsymbol{\sigma}^2 \tag{24} \end{equation} $$

For the calculation of $\mathbb{E}[\boldsymbol{m}_i^2]$, we can refer to 《Why Adam’s Update RMS is 0.2?》 or 《Rethinking Learning Rate and Batch Size (IV): EMA》. Using the above results, we obtain:

$$ \begin{equation}\Vert\boldsymbol{\theta}_t\Vert_{RMS}^2 \approx \beta_3^{2t}\Vert\boldsymbol{\theta}_0\Vert_{RMS}^2 + (1-\beta_3^t)^2 \frac{\Vert\boldsymbol{\mu}\Vert^2/\Vert\boldsymbol{\sigma}\Vert^2 + \frac{(1 - \beta_3)(1 + \beta_3^t)}{2(1 - \beta_3^t)}}{\lambda^2\left(\Vert\boldsymbol{\mu}\Vert^2/\Vert\boldsymbol{\sigma}\Vert^2 + \frac{1-\beta_1}{1 + \beta_1}\right)} \tag{25} \end{equation} $$

In particular, considering the limit where $\boldsymbol{\mu}=0$ and $t\to\infty$, we have:

$$ \begin{equation}\Vert\boldsymbol{\theta}_t\Vert_{RMS} \approx \sqrt{\frac{\eta}{2\lambda}\frac{1+\beta_1}{1 - \beta_1}} \tag{26} \end{equation} $$

This result is also reasonable, as the Update RMS of SignSGDMW is $\sqrt{\frac{1+\beta_1}{1 - \beta_1}}$ times that of AdamW. Therefore, its Weight RMS under the same $\eta, \lambda$ is also $\sqrt{\frac{1+\beta_1}{1 - \beta_1}}$ times.
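
The simulation script from earlier is easy to adapt to this case (my own sketch); with $\beta_1=0.9$, $\eta=0.001$ and $\lambda=0.1$, Eq. (26) predicts a steady-state Weight RMS of about $\sqrt{0.005\times 19}\approx 0.31$:

import numpy as np

# SignSGDM + weight decay, to compare against Eq. (26):
# sqrt(eta/(2*lam) * (1+beta1)/(1-beta1)) ≈ 0.31 for these values.
N, T, beta1, eta, lam = 10000, 100000, 0.9, 0.001, 0.1
m = 0.0
w = np.random.randn(N) * 0.1
for _ in range(T):
    g = np.random.randn(N)            # zero-mean gradients
    m = beta1 * m + (1 - beta1) * g
    w = w - eta * (np.sign(m) + lam * w)
print((w**2).mean()**0.5)             # compare with the ≈ 0.31 predicted by Eq. (26)

As before, Eq. (26) is a mean-field estimate, so one should expect agreement in order of magnitude rather than to the last digit.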

Related Analysis

As mentioned earlier, result (19) is consistent with the paper 《Rotational Equilibrium: How Weight Decay Balances Learning Across Neural Networks》, but our derivation method is completely different and yields the more general result (17). However, the original paper also has some interesting aspects, such as the concept of Total Update Contribution (TUC) it proposes, which is worth appreciating.

The idea behind TUC is as follows: Due to the presence of the momentum mechanism, the current gradient $\boldsymbol{g}_t$ does not only affect the current step; it also influences future steps (albeit with a “discount”). Therefore, assuming the number of training steps approaches infinity, we can consider the total contribution of the current gradient $\boldsymbol{g}_t$ to the entire training process. Specifically, for Adam, we have $\boldsymbol{u}_t=\boldsymbol{m}_t/\sqrt{\boldsymbol{v}_t}$. The contribution of the current $\boldsymbol{g}_t$ to $\boldsymbol{u}_t$ is $(1-\beta_1)\boldsymbol{g}_t/\sqrt{\boldsymbol{v}_t}$. In the next step, $\boldsymbol{g}_t$ will be discounted (multiplied by $\beta_1$), and the denominator will change to $\boldsymbol{v}_{t+1}$, and so on. Thus, the total contribution can be defined as:

$$ \begin{equation}\tilde{\boldsymbol{u}}_t = \sum_{k=t}^{\infty} (1-\beta_1)\beta_1^{k-t}\frac{\boldsymbol{g}_t}{\sqrt{\boldsymbol{v}_k}} \tag{27} \end{equation} $$

In this way, we decompose the updates $\boldsymbol{u}_1,\boldsymbol{u}_2,\boldsymbol{u}_3,\cdots$ into updates $\tilde{\boldsymbol{u}}_1,\tilde{\boldsymbol{u}}_2,\tilde{\boldsymbol{u}}_3,\cdots$. The advantage of this is that each $\tilde{\boldsymbol{u}}$ depends only on a single-step gradient, allowing us to repeat the derivation from the Quick Estimation section:

$$ \begin{equation}\Vert\boldsymbol{\theta}_t\Vert_{RMS}^2 = \Vert\beta_3\boldsymbol{\theta}_{t-1} + (1-\beta_3)(-\tilde{\boldsymbol{u}}_t/\lambda)\Vert_{RMS}^2 \approx \beta_3^2\Vert\boldsymbol{\theta}_{t-1}\Vert_{RMS}^2 + (1-\beta_3)^2\Vert\tilde{\boldsymbol{u}}_t\Vert_{RMS}^2/\lambda^2 \tag{28} \end{equation} $$

The final approximation relies on $\boldsymbol{\theta}_{t-1}\cdot\tilde{\boldsymbol{u}}_t\approx 0$. We assert that $\boldsymbol{\theta}_{t-1}\cdot\tilde{\boldsymbol{u}}_t$ is closer to zero than $\boldsymbol{\theta}_{t-1}\cdot\boldsymbol{u}_t$, because $\tilde{\boldsymbol{u}}_t$ depends only on the current gradient $\boldsymbol{g}_t$, which $\boldsymbol{\theta}_{t-1}$ has not yet encountered. The two are therefore independent, and once $\boldsymbol{g}_t$ is assumed to have zero mean, $\boldsymbol{\theta}_{t-1}\cdot\tilde{\boldsymbol{u}}_t\approx 0$ holds easily. To estimate $\Vert\tilde{\boldsymbol{u}}_t\Vert_{RMS}^2$, the original paper directly assumes that the vectors $\boldsymbol{g}_t/\sqrt{\boldsymbol{v}_k}$ (for different $k$) share the same direction and have unit RMS. Thus:

$$ \begin{equation}\Vert\tilde{\boldsymbol{u}}_t\Vert_{RMS} = \sum_{k=t}^{\infty} (1-\beta_1)\beta_1^{k-t}\left\Vert\frac{\boldsymbol{g}_t}{\sqrt{\boldsymbol{v}_k}}\right\Vert_{RMS} = \sum_{k=t}^{\infty} (1-\beta_1)\beta_1^{k-t} = 1 \tag{29} \end{equation} $$

Substituting into Eq. (28) and combining with the same approximation treatment from the Quick Estimation section, we solve for:

$$ \begin{equation}\Vert\boldsymbol{\theta}_t\Vert_{RMS} \approx \sqrt{\frac{\eta}{2\lambda}} \tag{30} \end{equation} $$

However, if we confine ourselves to the original paper, we find many approximations that seem arbitrary. For example, $\boldsymbol{v}_t$ also contains $\boldsymbol{g}_t$, so stating that $\tilde{\boldsymbol{u}}_t$ only includes the influence of the current $\boldsymbol{g}_t$ is not entirely accurate. Furthermore, the assertion $\Vert\boldsymbol{g}_t/\sqrt{\boldsymbol{v}_k}\Vert_{RMS}=1$ appears somewhat forced. But when viewed in the context of this article, we realize that under the mean-field approximation, the various operations in the original paper become quite reasonable, implying that the original paper actually implicitly used the mean-field method.
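
Incidentally, the TUC decomposition can be checked numerically: over a finite run of $T$ steps, the per-step updates $\boldsymbol{u}_t$ and the total contributions $\tilde{\boldsymbol{u}}_t$ of Eq. (27) (with the infinite sum truncated at $T$) must add up to exactly the same vector. A minimal sketch (all values illustrative):

import numpy as np

# Check that sum_t u_t equals sum_t u_tilde_t when Eq. (27) is truncated at step T.
N, T, beta1, beta2 = 1000, 200, 0.9, 0.95
g = np.random.randn(T, N)
m, v = 0.0, 0.0
u_sum, v_hist = 0.0, []
for t in range(T):
    m = beta1 * m + (1 - beta1) * g[t]
    v = beta2 * v + (1 - beta2) * g[t]**2
    v_hist.append(v)
    u_sum = u_sum + m / v**0.5
u_tilde_sum = sum(
    sum((1 - beta1) * beta1**(k - t) * g[t] / v_hist[k]**0.5 for k in range(t, T))
    for t in range(T)
)
print(np.allclose(u_sum, u_tilde_sum))  # True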

Summary

In this article, we used mean-field approximation to derive an interesting and potentially surprising conclusion: for models trained with AdamW, the RMS of their weights can also be asymptotically estimated, and generally, it only depends on the learning rate and Weight Decay.

@online{kexuefm-11307,
        title={Asymptotic Estimation of AdamW's Weight RMS},
        author={苏剑林},
        year={2025},
        month={10},
        url={\url{https://kexue.fm/archives/11307}},
}