
Hyperparameter Scaling Laws Across Model Scales


This is a gemini-2.5-flash translation of a Chinese article.

It has not been vetted for errors.

By Su Jianlin | 2025-03-13 | 30813 readers

As is well known, fully training a large LLM is expensive, which means we cannot repeatedly test hyperparameters directly on it. A very natural idea is to carefully search for hyperparameters on a small model with the same structure, find the optimal combination, and then transfer it directly to the large model. Although the idea is simple, implementing it is not trivial: it requires us to understand the scaling laws connecting common hyperparameters to model scale. muP is a practical realization of this idea.

muP, sometimes written as $\mu P$, stands for Maximal Update Parametrization. It originates from the paper 《Tensor Programs V: Tuning Large Neural Networks via Zero-Shot Hyperparameter Transfer》. With the popularization of LLM training, it has gradually become one of the de facto standards for scientific machine learning experimentation.

General Idea of the Method

Before delving into the topic, I must first air a small grievance: the original muP paper is excessively obscure, and its conclusions are not clearly articulated, which unnecessarily increases the difficulty of understanding it. Therefore, I will try to reproduce muP's conclusions in a (self-proclaimed) more concise manner.

First, let’s state the conclusion: muP primarily studies the transfer laws of hyperparameters across model scales. Here are a few keywords:

  1. Hyperparameters, currently mainly referring to learning rate;
  2. Model scale, currently mainly referring to model width;
  3. The core here is “transfer”.

Please note, muP does not study what the optimal hyperparameters are, but only how optimal hyperparameters change with model scale. Therefore, we need to search for the optimal hyperparameter combination on a small model and then transfer it to a large model. This is the use case and method of muP.

The principle behind deriving muP is to ensure that the model’s forward pass, backward pass, loss increment, and feature changes do not significantly vary with changes in model scale:

  1. Specifically, the approach is to analyze the order of magnitude of initialization and then assume that these conclusions represent the laws of subsequent optimization;
  2. In plain terms, it assumes that if initialization is done well, subsequent training will automatically follow the correct trajectory (a good start is half the battle?);
  3. Of course, one could also tell a story about the Law of Large Numbers or the Central Limit Theorem for this assumption, but I personally don’t think it’s necessary.

Forward Pass

We begin our discussion with the forward pass, as this is a relatively simple and well-established part. First, consider a linear layer $\boldsymbol{Y}=\boldsymbol{X}\boldsymbol{W}$, where $\boldsymbol{X}\in\mathbb{R}^{b\times d_{in}}$ and $\boldsymbol{W}\in\mathbb{R}^{d_{in}\times d_{out}}$. We use RMS (Root Mean Square) as the metric for matrix scale. For example:

$$ \text{RMS}(\boldsymbol{W}) = \sqrt{\frac{1}{d_{in} d_{out}}\sum_{i=1}^{d_{in}} \sum_{j=1}^{d_{out}} W_{i,j}^2} $$

We know that to make the RMS of $\boldsymbol{X}$ and $\boldsymbol{Y}$ roughly equal during the initialization phase (we will call this "stable" for short), $\boldsymbol{W}$ should use:

LeCun initialization: Random initialization with “mean of 0 and variance of $1/d_{in}$”.

This is already one of the fundamental conclusions in deep learning, so we will not elaborate on its derivation. Readers who are not yet familiar can refer to previous blog posts such as 《从几何视角来理解模型参数的初始化策略》 (Understanding Model Parameter Initialization Strategies from a Geometric Perspective) and 《浅谈Transformer的初始化、参数化与标准化》 (A Brief Discussion on Transformer Initialization, Parameterization, and Normalization).

Next, we consider a nonlinear layer $\boldsymbol{Y}=\phi(\boldsymbol{X}\boldsymbol{W})$, where $\phi$ is an element-wise activation function. If we still want to maintain approximate equality between the RMS of $\boldsymbol{X}$ and $\boldsymbol{Y}$, the result will be slightly different. For example, with $\text{relu}$ activation, we get:

Kaiming initialization: Random initialization with “mean of 0 and variance of $2/d_{in}$”.

It is easy to see that Kaiming initialization, compared to LeCun initialization, only differs by a constant factor of 2 in variance (which is independent of the model scale). It can be proven that results for other activation functions are similar. So we can draw the following conclusion:

fan_in initialization: To ensure the stability of the forward pass, one should use random initialization with “mean of 0 and variance proportional to $1/d_{in}$”.

This conclusion can also be understood as “the influence of activation functions is independent of model scale.” Therefore, if we only want to analyze the effect of model scale, we can ignore the presence of element-wise activation functions and directly derive the scaling law $\propto 1/d_{in}$ from LeCun initialization.
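This forward-pass claim is easy to check numerically. Below is a minimal NumPy sketch (the shapes and seed are arbitrary choices of this illustration): with variance-$1/d_{in}$ initialization, the RMS is preserved through a linear layer regardless of width.

```python
import numpy as np

rng = np.random.default_rng(0)

def rms(a):
    """Root mean square of all entries."""
    return np.sqrt(np.mean(a ** 2))

b = 256  # batch size
for d_in, d_out in [(128, 512), (1024, 4096)]:
    X = rng.standard_normal((b, d_in))                      # RMS(X) ≈ 1
    W = rng.standard_normal((d_in, d_out)) / np.sqrt(d_in)  # LeCun: var = 1/d_in
    Y = X @ W
    print(d_in, d_out, round(rms(Y), 2))                    # RMS(Y) stays ≈ 1 at every width
```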

Backward Pass

Now let’s continue analyzing the backward pass (gradients). Note that here we assume variables and their gradients have the same shape. Then we can calculate:

$$ \begin{align} \frac{\partial\mathcal{L}}{\partial \boldsymbol{W}} =&\, \boldsymbol{X}^{\top}\left(\frac{\partial\mathcal{L}}{\partial \boldsymbol{Y}}\otimes \phi'(\boldsymbol{X}\boldsymbol{W})\right) \\[5pt] \frac{\partial\mathcal{L}}{\partial \boldsymbol{X}} =&\, \left(\frac{\partial\mathcal{L}}{\partial \boldsymbol{Y}}\otimes \phi'(\boldsymbol{X}\boldsymbol{W})\right)\boldsymbol{W}^{\top} \end{align} $$

The first formula is the gradient of parameters within the current layer, and the second formula is the gradient propagated backward from this layer. $\otimes$ is the Hadamard product, and $\phi'$ is the derivative function of $\phi$.

Note a fact: the derivatives of commonly used activation functions can be bounded by a constant (which is independent of scale). So, at least in terms of order of magnitude, we can write:

$$ \begin{align} \frac{\partial\mathcal{L}}{\partial \boldsymbol{W}} =&\, \boldsymbol{X}^{\top}\left(\frac{\partial\mathcal{L}}{\partial \boldsymbol{Y}}\otimes \phi'(\boldsymbol{X}\boldsymbol{W})\right) \sim \boldsymbol{X}^{\top}\frac{\partial\mathcal{L}}{\partial \boldsymbol{Y}} \\[5pt] \frac{\partial\mathcal{L}}{\partial \boldsymbol{X}} =&\, \left(\frac{\partial\mathcal{L}}{\partial \boldsymbol{Y}}\otimes \phi'(\boldsymbol{X}\boldsymbol{W})\right)\boldsymbol{W}^{\top}\sim \frac{\partial\mathcal{L}}{\partial \boldsymbol{Y}}\boldsymbol{W}^{\top} \end{align} $$

Let’s first look at the second formula. Compared to $\boldsymbol{Y}=\boldsymbol{X}\boldsymbol{W}$, the matrix multiplied on its right side becomes $\boldsymbol{W}^{\top}$. Then, according to the conclusion from the previous section, if we want to maintain RMS stability for the backward pass, the initialization of $\boldsymbol{W}$ should be:

fan_out initialization: Random initialization with “mean of 0 and variance of $1/d_{out}$”.

When $d_{in}\neq d_{out}$, the requirements for forward and backward passes conflict. At this point, some proposed a compromise strategy:

Xavier initialization: Random initialization with “mean of 0 and variance of $2/(d_{in} + d_{out})$”.

This is also called "fan_avg initialization," because it simply takes the arithmetic average of $d_{in}$ and $d_{out}$. Other averaging methods can also be considered; refer to 《初始化方法中非方阵的维度平均策略思考》 (Considerations on Dimension Averaging Strategies for Non-Square Matrices in Initialization Methods). Xavier initialization appears to balance both the forward and backward passes, but one could equally say it balances neither. A better approach is to design the model so that most parameters are square matrices, as in the model family discussed later.
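The conflict between the two requirements can be illustrated with a quick sketch (arbitrary shapes chosen for this example): fan_in keeps the forward pass stable but shrinks the backward signal when $d_{out} \ll d_{in}$, while fan_out does the opposite.

```python
import numpy as np

rng = np.random.default_rng(1)

def rms(a):
    return np.sqrt(np.mean(a ** 2))

b, d_in, d_out = 256, 4096, 256
G = rng.standard_normal((b, d_out))  # stand-in for dL/dY, with RMS ≈ 1

W_fan_in  = rng.standard_normal((d_in, d_out)) / np.sqrt(d_in)
W_fan_out = rng.standard_normal((d_in, d_out)) / np.sqrt(d_out)

# Backward signal: dL/dX = (dL/dY) W^T
print(round(rms(G @ W_fan_in.T), 2))   # ≈ sqrt(d_out/d_in) = 0.25: shrinks
print(round(rms(G @ W_fan_out.T), 2))  # ≈ 1: stable
```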

Loss Increment

With the groundwork laid by the forward and backward passes, we can now attempt to analyze the increment of the loss function. Consider the change in the loss function when $\boldsymbol{W}\to \boldsymbol{W} + \Delta\boldsymbol{W}$:

$$ \Delta \mathcal{L} = \mathcal{L}(\boldsymbol{W} + \Delta\boldsymbol{W}) - \mathcal{L}(\boldsymbol{W})\approx \left\langle\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}}, \Delta\boldsymbol{W}\right\rangle_F $$

Here, $\langle\cdot,\cdot\rangle_F$ is the Frobenius inner product, which means calculating the vector inner product after flattening the matrix into a vector. Considering gradient descent $\Delta\boldsymbol{W} = -\eta \frac{\partial\mathcal{L}}{\partial \boldsymbol{W}}$, where $\eta$ is naturally the learning rate, and combining with the gradient equation for $\boldsymbol{W}$, we have:

$$ \Delta \mathcal{L}\approx -\eta\left\Vert\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}}\right\Vert_F^2\sim -\eta \left\Vert\boldsymbol{X}^{\top}\frac{\partial\mathcal{L}}{\partial \boldsymbol{Y}}\right\Vert_F^2 $$

In fact, this equation already tells us why the same learning rate $\eta$ cannot be used across different model scales:

  • $\boldsymbol{X}^{\top}\frac{\partial\mathcal{L}}{\partial \boldsymbol{Y}}$ is a $d_{in}\times d_{out}$ matrix;
  • $\left\Vert\boldsymbol{X}^{\top}\frac{\partial\mathcal{L}}{\partial \boldsymbol{Y}}\right\Vert_F^2$ is the sum of squares of $d_{in}\times d_{out}$ elements;
  • $\boldsymbol{X}^{\top}\frac{\partial\mathcal{L}}{\partial \boldsymbol{Y}}$ is exactly the product of forward and backward quantities;
  • If both forward and backward passes are stable, then each element of $\boldsymbol{X}^{\top}\frac{\partial\mathcal{L}}{\partial \boldsymbol{Y}}$ is $\mathcal{O}(1)$;
  • So $\left\Vert\boldsymbol{X}^{\top}\frac{\partial\mathcal{L}}{\partial \boldsymbol{Y}}\right\Vert_F^2$ is $\mathcal{O}(d_{in} d_{out})$.

The 4th point might require further elaboration. $\boldsymbol{X}^{\top}$ is a $d_{in}\times b$ matrix, and $\frac{\partial\mathcal{L}}{\partial \boldsymbol{Y}}$ is a $b\times d_{out}$ matrix. Their product involves $d_{in} d_{out}$ pairs of $b$-dimensional vectors performing inner products, where each inner product is a sum of $b$ terms. And the loss $\mathcal{L}$ is usually averaged over samples (i.e., it includes a division by $b$ operation). Therefore, if both $\boldsymbol{X}^{\top}$ and $\frac{\partial\mathcal{L}}{\partial \boldsymbol{Y}}$ are scale-independent, their product will generally also be scale-independent [i.e., their RMS values are all $\mathcal{O}(1)$].

The final conclusion indicates that if we directly apply the learning rate of a small model to a large model, then for sufficiently large models, the loss increment at each step will explode as the parameter scale (i.e., $d_{in} d_{out}$) increases. This means it’s impossible to replicate the convergence process of a small model, and may even lead to non-convergence due to excessively large steps.
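This growth is easy to see numerically. In the sketch below (arbitrary shapes; the loss is batch-averaged, so $\partial\mathcal{L}/\partial\boldsymbol{Y}$ carries a $1/b$ factor), the ratio $\Vert\boldsymbol{X}^{\top}\partial\mathcal{L}/\partial\boldsymbol{Y}\Vert_F^2 / (d_{in} d_{out})$ stays roughly constant as the width grows, i.e. the squared norm itself grows like $d_{in} d_{out}$.

```python
import numpy as np

rng = np.random.default_rng(2)
b = 128
for d in [256, 1024, 4096]:            # here d_in = d_out = d
    X = rng.standard_normal((b, d))    # stable forward features
    G = rng.standard_normal((b, d)) / b  # dL/dY for a batch-averaged loss
    sq = np.sum((X.T @ G) ** 2)        # ||X^T dL/dY||_F^2
    print(d, round(sq / d**2, 4))      # ratio ≈ const -> norm^2 = O(d_in d_out)
```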

At this point, one might think of scaling $\Delta\mathcal{L}$ by setting $\eta\propto 1/(d_{in} d_{out})$. In fact, this idea aligns with muP’s approach, but in practical scenarios, due to the aforementioned incompatibility between forward and backward passes, the fourth point “if both forward and backward passes are stable, then each element of $\boldsymbol{X}^{\top}\frac{\partial\mathcal{L}}{\partial \boldsymbol{Y}}$ is $\mathcal{O}(1)$” does not always hold true, making the actual situation more complex.

Model Assumptions

Now let’s consider a scenario closer to practice. Our task is to train a model mapping from $\mathbb{R}^{d_{in}}$ to $\mathbb{R}^{d_{out}}$, where $d_{in}$ and $d_{out}$ are determined by the data and cannot be changed. As we stated at the beginning, muP aims to study how hyperparameters scale with model size. Therefore, all fixed quantities are considered constants or $\mathcal{O}(1)$. For instance, an initialization variance of $1/d_{in}$ is equivalent to saying the initialization variance is $\mathcal{O}(1)$.

What we can change are parts like the model architecture and the number of parameters. However, muP primarily considers the scaling laws related to width, so let’s define the model architecture. The model family mainly considered here is:

$$ \begin{gathered} \boldsymbol{Y}_{in} = \boldsymbol{X} \boldsymbol{W}_{in} \\[5pt] \boldsymbol{Y}_{out} = \text{NN}(\boldsymbol{Y}_{in},\boldsymbol{\Theta}) \\[5pt] \boldsymbol{Z} = \boldsymbol{Y}_{out} \boldsymbol{W}_{out} \end{gathered} $$

Where:

  • $\boldsymbol{X}\in\mathbb{R}^{b\times d_{in}}$ (including batch size);
  • $\boldsymbol{W}_{in} \in \mathbb{R}^{d_{in}\times d}, \boldsymbol{W}_{out} \in \mathbb{R}^{d\times d_{out}}$;
  • $\text{NN}$ is any neural network mapping from $\mathbb{R}^d$ to $\mathbb{R}^d$;
  • Here, $d$ is actually what we commonly refer to as the hidden size;
  • We can arbitrarily increase $d$ to boost the model’s parameter count and potential;
  • muP aims to study the scaling laws of hyperparameters with respect to $d$.

More specifically, the $\text{NN}$ we consider here is a $K$-layer MLP:

$$ \begin{aligned} \boldsymbol{Y}_0 =&\, \boldsymbol{Y}_{in} \\[5pt] \boldsymbol{Y}_{k+1} =&\, \phi(\boldsymbol{Y}_k \boldsymbol{W}_{k+1}) \\[5pt] \boldsymbol{Y}_{out} =&\, \boldsymbol{Y}_K \end{aligned} $$

Here $\boldsymbol{\Theta}=\{\boldsymbol{W}_1,\boldsymbol{W}_2,\cdots,\boldsymbol{W}_K\}$, and $\boldsymbol{W}_k\in\mathbb{R}^{d\times d}$, meaning all are $d\times d$ square matrices, and all use fan_in initialization (equivalently, also fan_out initialization).

To clarify, assuming all parameter matrices are $d\times d$ square matrices here is purely for simplifying the analysis and not a strict requirement. The true purpose here is to assume that the parameters within $\text{NN}$ do not have scale-independent shapes. For example, a shape like $d\times 64$ is not allowed because $64$ is a constant, but a shape like $d\times 4d$ is allowed, because regardless of fan_in, fan_out, or fan_avg initialization, the variance will be proportional to $1/d$.
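As a concrete reference, here is a minimal sketch of this model family in NumPy (the names and the choice of $\phi=\tanh$ are mine, not from the paper), with fan_in initialization everywhere, as assumed so far:

```python
import numpy as np

def init_params(d_in, d, d_out, K, rng):
    """W_in: d_in×d, K square hidden matrices W_k: d×d, W_out: d×d_out.
    All use fan_in initialization (variance 1/fan_in) at this point."""
    W_in  = rng.standard_normal((d_in, d)) / np.sqrt(d_in)
    W_ks  = [rng.standard_normal((d, d)) / np.sqrt(d) for _ in range(K)]
    W_out = rng.standard_normal((d, d_out)) / np.sqrt(d)
    return W_in, W_ks, W_out

def forward(X, params, phi=np.tanh):
    W_in, W_ks, W_out = params
    Y = X @ W_in              # Y_in
    for W in W_ks:
        Y = phi(Y @ W)        # the K-layer MLP body NN(·)
    return Y @ W_out          # Z

rng = np.random.default_rng(0)
Z = forward(rng.standard_normal((8, 32)), init_params(32, 256, 10, K=3, rng=rng))
print(Z.shape)  # (8, 10)
```

Increasing `d` widens every hidden matrix while leaving the input/output shapes fixed, which is exactly the scaling direction muP studies.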

Putting It Together

Once the specific model is established, we can assemble all the previous conclusions. The parameters to be updated are divided into three parts: $\boldsymbol{W}_{in}$, $\boldsymbol{\Theta}$, and $\boldsymbol{W}_{out}$. Their gradients are calculated as follows:

$$ \begin{align} \frac{\partial\mathcal{L}}{\partial \boldsymbol{W}_{out}} =&\, \boldsymbol{Y}_{out}^{\top}\frac{\partial\mathcal{L}}{\partial \boldsymbol{Z}} \\[6pt] \frac{\partial\mathcal{L}}{\partial \boldsymbol{W}_k} =&\, \frac{\partial \boldsymbol{Y}_{out}}{\partial \boldsymbol{W}_k} \cdot\frac{\partial\mathcal{L}}{\partial \boldsymbol{Y}_{out}} = \frac{\partial \boldsymbol{Y}_{out}}{\partial \boldsymbol{W}_k} \cdot\left(\frac{\partial\mathcal{L}}{\partial \boldsymbol{Z}}\boldsymbol{W}_{out}^{\top}\right) \\[6pt] \frac{\partial\mathcal{L}}{\partial \boldsymbol{W}_{in}} =&\, \boldsymbol{X}^{\top} \frac{\partial\mathcal{L}}{\partial \boldsymbol{Y}_{in}} = \boldsymbol{X}^{\top} \left(\frac{\partial\boldsymbol{Y}_{out}}{\partial \boldsymbol{Y}_{in}}\cdot\frac{\partial\mathcal{L}}{\partial \boldsymbol{Y}_{out}}\right) = \boldsymbol{X}^{\top} \left(\frac{\partial\boldsymbol{Y}_{out}}{\partial \boldsymbol{Y}_{in}}\cdot\left(\frac{\partial\mathcal{L}}{\partial \boldsymbol{Z}}\boldsymbol{W}_{out}^{\top}\right)\right) \\[6pt] \end{align} $$

The $\cdot$ operation here requires a slight explanation: Both $\boldsymbol{Y}_{in}$ and $\boldsymbol{Y}_{out}$ are matrices, so $\frac{\partial\boldsymbol{Y}_{out}}{\partial \boldsymbol{Y}_{in}}$ is, in principle, a fourth-order tensor. The chain rule $\frac{\partial\boldsymbol{Y}_{out}}{\partial \boldsymbol{Y}_{in}}\cdot\frac{\partial\mathcal{L}}{\partial \boldsymbol{Y}_{out}}$ is actually a multiplication of higher-order tensors. However, we won’t elaborate on it here, so we simply use a $\cdot$ to denote it. Readers only need to know that it is a generalization of matrix multiplication.

Now let’s observe the patterns:

  • All three equations contain $\frac{\partial\mathcal{L}}{\partial \boldsymbol{Z}}$;
  • The latter two equations both contain $\boldsymbol{W}_{out}^{\top}$;
  • All $\boldsymbol{W}_k$ are square matrices, and $\frac{\partial\boldsymbol{Y}_{out}}{\partial \boldsymbol{Y}_{in}}$ and $\frac{\partial \boldsymbol{Y}_{out}}{\partial \boldsymbol{W}_k}$ are stable [RMS is $\mathcal{O}(1)$];
  • If $\boldsymbol{W}_{in}$ also uses fan_in initialization, then $\boldsymbol{Y}_{out}$ is also stable;
  • To make $\frac{\partial\mathcal{L}}{\partial \boldsymbol{Z}}\boldsymbol{W}_{out}^{\top}$ stable, the initialization variance should be $1/d_{out}$. However, $d_{out}$ is scale-independent, effectively a constant.

As a result:

  • The RMS of $\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}_{out}}$ is $\mathcal{O}(1)$, and $\left\Vert\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}_{out}}\right\Vert_F^2$ is the sum of squares of $d\times d_{out}$ elements, so its magnitude is $\mathcal{O}(d\times d_{out})$. Don’t forget that $d_{out}$ is a constant, so it’s effectively $\mathcal{O}(d)$. Therefore, to achieve an $\mathcal{O}(1)$ $\Delta\mathcal{L}$, its learning rate must satisfy $\eta_{out}\propto 1/d$;
  • $\left\Vert\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}_k}\right\Vert_F^2$ is the sum of $d^2$ terms. The RMS of $\frac{\partial \boldsymbol{Y}_{out}}{\partial \boldsymbol{W}_k}$ and $\frac{\partial\mathcal{L}}{\partial \boldsymbol{Z}}$ are both $\mathcal{O}(1)$. If we directly set the initialization variance of $\boldsymbol{W}_{out}$ to be $\propto 1/d^2$, then the RMS of $\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}_k}$ will be $\mathcal{O}(1/d)$. After squaring and summing, it becomes exactly $\mathcal{O}(1)$, so the learning rate does not need to change;
  • At this point, the RMS of $\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}_{in}}$ is also $\mathcal{O}(1/d)$. However, $\left\Vert\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}_{in}}\right\Vert_F^2$ is only the sum of squares of $d_{in}\times d$ elements, so the result is $\mathcal{O}(1/d)$. To achieve an $\mathcal{O}(1)$ $\Delta\mathcal{L}$, the learning rate instead needs to be amplified by a factor of $d$ to counteract this effect, i.e., $\eta_{in}\propto d$.
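These scalings can be checked with random surrogate matrices (a sketch; batch averaging is dropped since only the dependence on $d$ matters): with $\text{Var}[\boldsymbol{W}_{out}] \propto 1/d^2$, both $\Vert\partial\mathcal{L}/\partial\boldsymbol{W}_{out}\Vert_F^2/d$ and $\Vert\partial\mathcal{L}/\partial\boldsymbol{W}_k\Vert_F^2$ stay roughly constant as $d$ grows.

```python
import numpy as np

rng = np.random.default_rng(3)
b, d_out = 64, 10
for d in [256, 1024, 4096]:
    Y_out = rng.standard_normal((b, d))          # stable hidden features, RMS O(1)
    G_Z   = rng.standard_normal((b, d_out))      # dL/dZ, RMS O(1)
    W_out = rng.standard_normal((d, d_out)) / d  # variance ∝ 1/d^2
    g_out = Y_out.T @ G_Z                        # dL/dW_out
    G_Y   = G_Z @ W_out.T                        # signal entering the hidden stack
    g_k   = Y_out.T @ G_Y                        # ~ dL/dW_k up to stable factors
    print(d,
          round(np.sum(g_out ** 2) / d, 1),      # ||·||_F^2 = O(d): ratio ≈ const
          round(np.sum(g_k ** 2), 1))            # O(1): roughly constant in d
```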

Feature Changes

The above results are correct, but upon careful consideration, we find an issue in the derivation process. Points 2 and 3 above are both based on the assumption that “we directly set the initialization variance of $\boldsymbol{W}_{out}$ to be $\propto 1/d^2$”. However, this setting currently lacks direct justification. Without further explanation, the derivation process would be incomplete.

In fact, if we only consider the requirement $\Delta \mathcal{L}=\mathcal{O}(1)$, it’s indeed impossible to rule out other possibilities. For instance, if the initialization variance of $\boldsymbol{W}_{out}$ is set to $\propto 1/d$, then the RMS of $\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}_k}$ would be $\mathcal{O}(1/\sqrt{d})$. After squaring and summing, it becomes $\mathcal{O}(d)$, and then a learning rate $\eta\propto 1/d$ could also achieve $\Delta \mathcal{L}=\mathcal{O}(1)$. Therefore, to explain the necessity of setting the initialization variance of $\boldsymbol{W}_{out}$ to be $\propto 1/d^2$, new conditions need to be introduced.

The loss function $\mathcal{L}$ is a macroscopic, or external, metric of the model. Looking solely at its changes is insufficient to explain all results; thus, we need to delve into the model’s internal workings. Specifically, we want the change in the output of each layer (often called features, or sometimes activations) to also be scale-invariant. For example, for a linear layer $\boldsymbol{Y}_k = \boldsymbol{Y}_{k-1} \boldsymbol{W}_k$, the output change caused by parameters $\boldsymbol{W}_k\to \boldsymbol{W}_k + \Delta \boldsymbol{W}_k$ is:

$$ \Delta\boldsymbol{Y}_k = \boldsymbol{Y}_{k-1} (\boldsymbol{W}_k + \Delta \boldsymbol{W}_k) - \boldsymbol{Y}_{k-1} \boldsymbol{W}_k = \boldsymbol{Y}_{k-1} \Delta\boldsymbol{W}_k $$

Note that $\boldsymbol{Y}_{k-1}\in\mathbb{R}^{b\times d}$ and $\Delta\boldsymbol{W}_k\in\mathbb{R}^{d\times d}$, so $\boldsymbol{Y}_{k-1} \Delta\boldsymbol{W}_k$ consists of $b\times d$ inner products of $d$-dimensional vector pairs. Note also that $\Delta\boldsymbol{W}_k$ is a carefully designed update quantity: unlike at initialization, it is unlikely to be independent of $\boldsymbol{Y}_{k-1}$. Therefore, each “inner product of a $d$-dimensional vector pair” is more likely to be $\mathcal{O}(d)$ (a $d$-dimensional inner product is a sum of $d$ terms). Thus, if the RMS of $\boldsymbol{Y}_{k-1}$ is $\mathcal{O}(1)$, then the RMS of $\Delta\boldsymbol{Y}_k$ can be taken as $\mathcal{O}(d\times \text{RMS}(\Delta \boldsymbol{W}_k))$.

Therefore, to make the RMS of $\Delta\boldsymbol{Y}_k$ equal to $\mathcal{O}(1)$, we get an additional requirement for $\Delta \boldsymbol{W}_k$:

$$ \text{RMS}(\Delta \boldsymbol{W}_k) = \mathcal{O}(1 / d) $$

Combining $\Delta \boldsymbol{W}_k = -\eta\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}_k}$ and $\Delta\mathcal{L}=\mathcal{O}(1)$, we can then derive the result that “the initialization variance of $\boldsymbol{W}_{out}$ is set to $\propto 1/d^2$”.
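The crucial step here, that a gradient-aligned update makes the "inner product of $d$-dimensional vector pairs" scale like $\mathcal{O}(d)$ rather than $\mathcal{O}(\sqrt{d})$, can be illustrated numerically (a sketch using a gradient-shaped update $\Delta\boldsymbol{W}\propto \boldsymbol{Y}^{\top}\boldsymbol{G}$, normalized to unit RMS):

```python
import numpy as np

rng = np.random.default_rng(5)
b = 64
for d in [256, 1024, 4096]:
    Y  = rng.standard_normal((b, d))
    G  = rng.standard_normal((b, d))
    dW = Y.T @ G                     # update correlated with Y, like a gradient
    dW = dW / np.sqrt(np.mean(dW ** 2))  # normalize so that RMS(dW) = 1
    dY = Y @ dW
    # ratio ≈ const across d: RMS(dY) = O(d · RMS(dW)), not O(sqrt(d))
    print(d, round(np.sqrt(np.mean(dY ** 2)) / d, 3))
```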

(Note: This section relies on the guidance of @Chenyu Zheng. Many thanks!)

Adam Version

The above was muP for SGD. For Adam, we typically use SignSGD as an approximation for order-of-magnitude analysis:

  • $\Delta \boldsymbol{W} = -\eta \mathop{\text{sign}}\left(\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}}\right)$;
  • $\Delta \mathcal{L} \approx -\eta \left|\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}}\right|_1$;
  • Here, $|\cdot|_1$ refers to taking the absolute value of each element and then summing them.

Regarding the SignSGD approximation itself, readers can also refer to articles such as 《当Batch Size增大时,学习率该如何随之变化?》 (How Should the Learning Rate Change When Batch Size Increases?) and 《Adam的epsilon如何影响学习率的Scaling Law?》 (How Adam’s Epsilon Affects the Learning Rate Scaling Law?). We will not elaborate on it here. In summary, SignSGD is a commonly used approximation method when analyzing Adam-related scaling laws.

Now we can analyze by imitating the SGD process:

  • The RMS of $\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}_{out}}$ is $\mathcal{O}(1)$. $\left|\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}_{out}}\right|_1$ is the sum of $d\times d_{out}$ elements, so its magnitude is $\mathcal{O}(d\times d_{out}) = \mathcal{O}(d)$. Therefore, its learning rate must satisfy $\eta_{out}\propto 1/d$ to counteract the scale effect;
  • $\left|\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}_k}\right|_1$ is the sum of $d^2$ elements. The RMS of $\frac{\partial \boldsymbol{Y}_{out}}{\partial \boldsymbol{W}_k}$ and $\frac{\partial\mathcal{L}}{\partial \boldsymbol{Z}}$ are both $\mathcal{O}(1)$. If we set the initial variance of $\boldsymbol{W}_{out}$ to be $\propto 1/d^2$, then the RMS of $\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}_k}$ will be $\mathcal{O}(1/d)$. After summing $d^2$ elements, it becomes $\mathcal{O}(d)$, so the learning rate should transform as $\eta_k\propto 1/d$ to counteract the scale effect;
  • At this point, the RMS of $\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}_{in}}$ is also $\mathcal{O}(1/d)$. However, $\left|\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}_{in}}\right|_1$ is merely the sum of $d_{in}\times d$ elements, so it is already $\mathcal{O}(1)$, and thus the learning rate does not need to change with scale.

(Note: Readers can verify that the condition $\text{RMS}(\Delta \boldsymbol{W}_k) = \mathcal{O}(1 / d)$ is satisfied.)
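The SignSGD approximation and its loss increment can be sketched in a few lines (a toy check of $\Delta \mathcal{L} \approx -\eta \left|\partial\mathcal{L}/\partial \boldsymbol{W}\right|_1$, not an Adam implementation):

```python
import numpy as np

def signsgd_step(W, grad, lr):
    """One SignSGD update, the order-of-magnitude proxy for Adam used here."""
    return W - lr * np.sign(grad)

grad = np.array([[0.5, -2.0], [1.5, -0.25]])
lr = 0.01
dW = signsgd_step(np.zeros_like(grad), grad, lr)  # the update itself
dL = np.sum(grad * dW)      # first-order loss change <grad, dW>_F
print(dL, -lr * np.abs(grad).sum())  # both ≈ -0.0425: dL = -lr * |grad|_1
```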

Muon Version

Next, the analysis of Muon is naturally indispensable. As for Muon itself, we have already provided detailed introductions in 《Muon优化器赏析:从向量到矩阵的本质跨越》 (Appreciation of Muon Optimizer: An Essential Leap from Vectors to Matrices) and 《Muon续集:为什么我们选择尝试Muon?》 (Muon Sequel: Why Did We Choose to Try Muon?). We will not repeat them here. Similar to Adam using SignSGD, we use MSignSGD to approximate Muon:

  • $\Delta \boldsymbol{W} = -\eta \mathop{\text{msign}}\left(\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}}\right)$;
  • $\Delta \mathcal{L} \approx -\eta \left\Vert\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}}\right\Vert_*$ (proof can be found in 《Muon优化器赏析:从向量到矩阵的本质跨越》);
  • Here, $\Vert\cdot\Vert_*$ refers to the Nuclear norm, which is the sum of all singular values of a matrix;
  • The Nuclear norm is not easy to compute, but the Frobenius norm (F-norm) is. It equals the square root of the sum of squares of all singular values of a matrix;
  • We use the F-norm as an approximation for the Nuclear norm. Therefore, $\Delta \mathcal{L} \approx -\eta \left\Vert\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}}\right\Vert_*\approx -\eta \left\Vert\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}}\right\Vert_F$;
  • The F-norm is also equal to the square root of the sum of squares of all elements of the matrix.

So we can begin the analysis process:

  • The RMS of $\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}_{out}}$ is $\mathcal{O}(1)$, so the magnitude of $\left\Vert\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}_{out}}\right\Vert_*$ is $\mathcal{O}(\sqrt{d\times d_{out}}) = \mathcal{O}(\sqrt{d})$. To eliminate the scale effect, its learning rate must satisfy $\eta_{out}\propto 1/\sqrt{d}$;
  • $\left\Vert\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}_k}\right\Vert_F$ is the square root of the sum of squares of $d^2$ elements. The RMS of $\frac{\partial \boldsymbol{Y}_{out}}{\partial \boldsymbol{W}_k}$ and $\frac{\partial\mathcal{L}}{\partial \boldsymbol{Z}}$ are both $\mathcal{O}(1)$. If we set the initial variance of $\boldsymbol{W}_{out}$ to be $\propto 1/d^2$, then the RMS of $\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}_k}$ will be $\mathcal{O}(1/d)$. After summing squares and then taking the square root, the result is $\mathcal{O}(1)$, so the learning rate does not need to change;
  • At this point, the RMS of $\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}_{in}}$ is also $\mathcal{O}(1/d)$. However, $\left\Vert\frac{\partial\mathcal{L}}{\partial \boldsymbol{W}_{in}}\right\Vert_F$ is merely the square root of the sum of squares of $d_{in}\times d$ elements, so it is $\mathcal{O}(1/\sqrt{d})$. The learning rate, conversely, needs to be amplified by a factor of $\sqrt{d}$ to counteract this effect, i.e., $\eta_{in}\propto \sqrt{d}$.

(Note: While the conclusion for Muon here is correct, it does not satisfy the condition $\text{RMS}(\Delta \boldsymbol{W}_k) = \mathcal{O}(1 / d)$ because, upon closer examination, this condition relies on the assumption that the update quantity is element-wise, which Muon does not comply with. Therefore, it is practically inapplicable. We have not extensively discussed this here but rather directly adopted the conclusion that “the initialization variance of $\boldsymbol{W}_{out}$ is set to $\propto 1/d^2$”, bypassing this condition.)
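For reference, $\text{msign}$ and the identity $\langle \boldsymbol{G}, \text{msign}(\boldsymbol{G})\rangle_F = \Vert\boldsymbol{G}\Vert_*$ can be sketched via the SVD (a toy illustration; Muon in practice typically approximates this with a Newton–Schulz iteration rather than an explicit SVD):

```python
import numpy as np

def msign(G):
    """U V^T from the thin SVD of G: every singular value snapped to 1."""
    U, _, Vt = np.linalg.svd(G, full_matrices=False)
    return U @ Vt

rng = np.random.default_rng(4)
G = rng.standard_normal((8, 5))
M = msign(G)
s = np.linalg.svd(G, compute_uv=False)
# <G, msign(G)>_F = trace(V S V^T) = sum of singular values = nuclear norm
print(np.allclose(np.sum(G * M), s.sum()))
```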

Summary of Conclusions

Summarizing the above conclusions:

| Optimizer | $\boldsymbol{W}_{in}$ Variance | $\boldsymbol{W}_{in}$ Learning Rate | $\boldsymbol{W}_k$ Variance | $\boldsymbol{W}_k$ Learning Rate | $\boldsymbol{W}_{out}$ Variance | $\boldsymbol{W}_{out}$ Learning Rate |
|---|---|---|---|---|---|---|
| SGD | $1/d_{in}$ | $d$ | $1/d$ | $1$ | $1/d^2$ | $1/d$ |
| Adam | $1/d_{in}$ | $1$ | $1/d$ | $1/d$ | $1/d^2$ | $1/d$ |
| Muon | $1/d_{in}$ | $\sqrt{d}$ | $1/d$ | $1$ | $1/d^2$ | $1/\sqrt{d}$ |

Here, $\boldsymbol{W}_k$ refers to all parameters except $\boldsymbol{W}_{in}$ and $\boldsymbol{W}_{out}$. It’s also important to emphasize that the relationships here are “proportional to” rather than “equal to.” Additionally, in practice, slight modifications can be made based on specific requirements. For example, when we actually use Muon, the optimization of $\boldsymbol{W}_{in}$ and $\boldsymbol{W}_{out}$ typically uses Adam instead of Muon, which will lead to two changes:

  • $\eta_{out}\propto 1/d$;
  • $\eta_{in}\propto 1$;

If combined with the Adjust LR mentioned in our paper 《Muon is Scalable for LLM Training》, the learning rate is further multiplied by a factor of $\sqrt{\max(n, m)}$, where $n\times m$ is the shape of the parameter matrix. We have already assumed that the parameters in the $\text{NN}$ part always scale proportionally, so $\sqrt{\max(n, m)}\propto \sqrt{d}$. Therefore, to counteract the scale effect introduced by Adjust LR, it is necessary that:

  • $\eta_k\propto 1/\sqrt{d}$.
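The summary table can be packaged as a small helper. This is a hypothetical convenience function of mine, not from the paper: the table only gives proportionalities, so the helper transfers a learning rate tuned at width $d_0$ to width $d$ by the corresponding power of $d/d_0$.

```python
# η ∝ d**e for each (optimizer, part), per the summary table above.
EXPONENTS = {
    ("sgd",  "in"):  1.0, ("sgd",  "hidden"):  0.0, ("sgd",  "out"): -1.0,
    ("adam", "in"):  0.0, ("adam", "hidden"): -1.0, ("adam", "out"): -1.0,
    ("muon", "in"):  0.5, ("muon", "hidden"):  0.0, ("muon", "out"): -0.5,
}

def mup_lr(base_lr, d0, d, optimizer, part):
    """Transfer a learning rate tuned at width d0 to width d."""
    return base_lr * (d / d0) ** EXPONENTS[(optimizer, part)]

print(mup_lr(0.1, 256, 1024, "adam", "hidden"))  # 0.025: shrinks as 1/d
```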

Summary

This article introduces muP (Maximal Update Parametrization) in the most concise and clear way possible. This work aims to study the transfer laws of hyperparameters across model scales. Based on muP, we can carefully search for hyperparameters (mainly learning rate and initialization here) on small models at a relatively low cost, and then transfer them to large models, thereby reducing the training cost for large models.

Objectively speaking, the introduction and analysis here are still relatively preliminary. For example, bias terms were not considered, the generality of conclusions for architectures other than MLPs was not evaluated, nor were the effects of Normalization and residual connections thoroughly examined. Not considering bias terms was purely out of laziness; consider it an exercise for the readers. As for muP in different architectures, analysis is generally more troublesome, but due to the similarity of neural networks, the conclusions are roughly the same, and we can use them without rigorous proof. I personally believe that more crucial areas for improvement are the influences of Normalization and residual connections, especially Normalization, which allows for stable forward propagation without relying on special initialization, bringing greater freedom and possibilities.

Of course, these are left for future analysis.

@online{kexuefm-10770,
        title={A Preliminary Look at muP: Hyperparameter Scaling Laws Across Model Scales},
        author={苏剑林},
        year={2025},
        month={03},
        url={\url{https://kexue.fm/archives/10770}},
}