A short article about Muon, prompted by a question about the Kimi Moonlight paper.
Premise#
Assume you have an arbitrary PPxFSDPxTP’d model.
For the purposes of the optimizer, PP can be ignored, and FSDP/TP can be seen as double-sharding parameters.
So really, we can just imagine each parameter \(W\) is sharded into \(w_1, \dots, w_w\) shards (which may or may not be contiguous, but are distributed evenly across ranks). Each parameter shard \(w_i\) will have a corresponding grad shard \(g_i\).
Also, for the purposes of this post, I will ignore the role of small parameters (like norms or modulation).
- The correct analysis for them is more latency driven than bw/comp driven.
- Ideally you have conditional code to keep them unsharded && rely on log-scaled allreduce to minimize grad sync time && keep optim update simple.
Variants#
Here are all variants I thought of in the last 4 hours.
Replication#
Dumbest baseline: duplicate momentum && compute on all GPUs.
On each GPU, the following compute is done (per parameter): $$ G = \text{allgather}(g_i) \newline M_t = \mu M_{t-1} + G\newline X_t = \mu M_t + G\newline O_t = \text{NS}(X_t)\newline W_t = W_{t-1} + \eta_tO_t\newline $$ This is a stupid baseline, because it requires
- extra comms (as much as FSDP backwards does)
- extra compute – both scalar momentum/param update && matmuls in NS
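To make the baseline concrete, here is a minimal single-process sketch in NumPy, where a list of arrays stands in for the per-rank grad shards and `np.concatenate` stands in for the allgather. The quintic coefficients `(3.4445, -4.7750, 2.0315)`, the Frobenius normalization, and the 5 steps are assumptions taken from common Muon implementations, not something this post specifies. The sign convention follows the equation above.

```python
import numpy as np

def newton_schulz(X, steps=5, a=3.4445, b=-4.7750, c=2.0315, eps=1e-7):
    """Quintic Newton-Schulz iteration (coefficients and step count are assumed)."""
    X = X / (np.linalg.norm(X) + eps)  # Frobenius-normalize so all singular values are <= 1
    for _ in range(steps):
        A = X @ X.T
        B = b * A + c * A @ A
        X = a * X + B @ X
    return X

def replicated_step(W, M, grad_shards, mu, lr):
    """One fully replicated update, exactly as every rank would redundantly compute it."""
    G = np.concatenate(grad_shards, axis=0)  # stands in for allgather(g_i)
    M = mu * M + G
    X = mu * M + G
    O = newton_schulz(X)
    return W + lr * O, M

rng = np.random.default_rng(0)
ws, R, C = 4, 8, 32
W = rng.standard_normal((R, C))
M = np.zeros((R, C))
grad_shards = [rng.standard_normal((R // ws, C)) for _ in range(ws)]

W_new, M_new = replicated_step(W, M, grad_shards, mu=0.95, lr=0.02)
# Recover the applied update and inspect its singular values.
singular_values = np.linalg.svd((W_new - W) / 0.02, compute_uv=False)
```

With these assumed coefficients the singular values of the update land in a band around 1 rather than at exactly 1, which is the expected behavior of this particular iteration.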
NS Replication#
Sharding the update for \(M_t\) and \(W_t\), while keeping the \(O_t =\ …\ \) step replicated, is free iff the dtype of \(g_i\) matches that of \(M_t\).
- For PyTorch FSDP, iff the underlying base model parallelized has fp32 master weights (which should be true-by-default), and the momentum is fp32 (which is also typically true), the dtype will match.
Let \(m_i\) be the local shard of \(M\). For simplicity, I omit the \(_i\) subscript where its presence is implied.
On each GPU (rank \(i\)), the following compute is done: $$ m_t = \mu m_{t-1} + g_i\newline x_t = \mu m_t+g_i\newline X_t = \text{allgather}(x_t)\newline O_t = \text{NS}(X_t)\newline w_t = w_{t-1} + \eta_tO_t[i] $$
This is the baseline in fsdp_optimizers, muon_fsdp2.py, and Algorithm 1 in Kimi Moonlight.
It has:
- identical comms requirements to the fully replicated case (conditional on dtype)
- sharded m/w scalar compute, but still fully replicated matmul compute
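The reason sharding the momentum update is safe is that it is purely elementwise, so gathering the per-rank \(x_t\) gives the same \(X_t\) as the replicated path. A small NumPy check of that claim, with list entries standing in for ranks (the helper name `nesterov_input` is made up for illustration):

```python
import numpy as np

def nesterov_input(m_prev, g, mu):
    """Momentum update followed by the NS input; works on a full tensor or any shard."""
    m = mu * m_prev + g
    return m, mu * m + g

rng = np.random.default_rng(0)
ws, R, C, mu = 4, 8, 6, 0.95
G = rng.standard_normal((R, C))
M_prev = rng.standard_normal((R, C))

# Replicated path: every rank holds the full G and M.
_, X_full = nesterov_input(M_prev, G, mu)

# Sharded path: rank i only touches its own row block, then the blocks are gathered.
g_shards = np.split(G, ws, axis=0)
m_shards = np.split(M_prev, ws, axis=0)
x_shards = [nesterov_input(m, g, mu)[1] for m, g in zip(m_shards, g_shards)]
X_gathered = np.concatenate(x_shards, axis=0)  # stands in for allgather(x_t)
```

Both paths feed an identical matrix into NS, so the only thing that changes is where the scalar work happens.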
Can we do better?
A note about collectives#
Before we continue, let me briefly dig into some basics about collective comms (under certain assumptions).
Both reduce-scatter and all-gather NCCL collectives are generally implemented using the same ring-reduction strategy for large tensors (approx global tensor data size \(S > 128\text{KiB}\)). Letting \(ws\) be the world size of a collective, the ring-reduce approach:
- has \(O(ws)\) growth in latency, which heuristically becomes infeasible around 100-1000 GPUs
- has a total I/O requirement \(\lim_{ws\to\infty} \text{bus volume} = S\),
- and therefore a total bus bandwidth requirement of \(\lim_{ws\to\infty}B = S/t\), where \(t\) is the time the collective takes
Because I am GPU poor, I will ignore the latency considerations. If you are not GPU poor, it is possible for larger fat-tree clusters to apply faster log-n algorithms (derivative of Bruck’s) – but this is unfair to TPU torus users, so IMO you should ignore it in general.
In contrast, an all-reduce collective with large \(ws\) can be done with a \(O(\log ws)\) tree reduction algorithm, or even \(O(1)\) with both NVLS + IBSharp. Its bus bandwidth requirement is \(2S/t\) in the limit without SHArP; with SHArP it is a much smaller number whose exact value I don’t know.
- I have heard that many clusters keep IBSharp disabled for stability reasons, so I will assume that SHArP is never applied beyond intranode NVSwitch level, and that it is therefore acceptable to round off the impact of SHArP to a constant dividing factor
Algo | Latency | Bus vol |
---|---|---|
AG/RS (ring) | \(O(ws)\) | \(S\) |
AR (tree) | \(O(\log ws)\) | \(2S\) |
AR (idealized SHArP) | \(O(1)\) (large const factor) | \(? \ll S\) |
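The limits in the table come from the usual \((ws-1)/ws\) accounting for bandwidth-optimal collectives. A small sketch of that accounting, assuming the standard factors of \((ws-1)/ws\) for all-gather/reduce-scatter and \(2(ws-1)/ws\) for all-reduce (the function name is made up for illustration):

```python
def bus_volume(S: float, ws: int, collective: str) -> float:
    """Per-rank data moved for a tensor of global size S, under the assumed accounting."""
    factor = (ws - 1) / ws
    if collective in ("all_gather", "reduce_scatter"):
        return S * factor
    if collective == "all_reduce":
        return 2 * S * factor
    raise ValueError(f"unknown collective: {collective}")

small_ag = bus_volume(1.0, 8, "all_gather")     # 7/8 of S on 8 ranks
small_ar = bus_volume(1.0, 8, "all_reduce")     # twice that
large_ag = bus_volume(1.0, 1024, "all_gather")  # approaches S as ws grows
```

At 8 ranks the all-gather already moves 87.5% of \(S\), so rounding to the limit loses very little.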
Also, some basic numbers to consider for a typical H100 3.2Tbps IB cluster, for bf16:
Source | FLOP/s | Bidirectional Bus Bandwidth (intra) | Bi Bus Bw (inter) |
---|---|---|---|
On-paper | \(0.9894*10^{15}\) | \(0.45*10^{12}\) elements-per-second | \(0.025*10^{12}\) E/s |
Real | ~\(0.7*10^{15}\) | ~\(0.37*10^{12}\) E/s | idk |
It suffices to care about bidirectional bandwidth only in most cases, because both ring/tree algorithms involve bidirectional comms for most steps. This doesn’t apply for SHArP, but I will ignore SHArP.
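For reference, here is one plausible reading of where the on-paper bandwidth figures come from. The 900 GB/s NVLink number and the 8-GPU node size are my assumptions about the hardware, not stated in this post; only the 3.2Tbps figure and the bf16 element size are from the text.

```python
BYTES_PER_BF16 = 2

def elements_per_second(bytes_per_second: float, bytes_per_element: int = BYTES_PER_BF16) -> float:
    """Convert a byte bandwidth into an element bandwidth for a given dtype size."""
    return bytes_per_second / bytes_per_element

nvlink_bytes_per_s = 900e9      # assumed: H100 SXM NVLink, bidirectional, per GPU
ib_bits_per_s_per_node = 3.2e12 # from the text: 3.2Tbps per node
gpus_per_node = 8               # assumed node size

intra = elements_per_second(nvlink_bytes_per_s)
inter = elements_per_second(ib_bits_per_s_per_node / 8 / gpus_per_node)
```

Under these assumptions `intra` and `inter` reproduce the \(0.45*10^{12}\) and \(0.025*10^{12}\) E/s entries in the on-paper row.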
TP Newton-Schulz#
To fully shard the compute requirements across all accelerators, it suffices to implement the \(\text{NS}\) method with sharded matmuls. The main compute loop implemented is:
```python
for _ in range(N):
    A = X @ X.T
    B = b * A + c * A @ A
    X = a * X + B @ X
```
When \(X\) is a row-sharded (`Shard(-2)`, i.e. `Shard(0)` for a matrix) DTensor, the compute here becomes perfectly sharded by default as well.
- `X @ X.T` -> `Shard(0) @ Shard(1)` -> double-sharded -> auto-allgather to `Shard(0)`
- `b*A`, `c*A` are both sharded scalar compute, because A is `Shard(0)`
- `c*A @ A` -> `Shard(0) @ Shard(0)` -> auto-allgather to `Shard(0)`
- `b*A + c*A @ A` -> `Shard(0) + Shard(0)` -> `Shard(0)`
- `a*X + B@X` -> `Shard(0) + Shard(0)@Shard(0)` -> `Shard(0)` (after auto-allgather)
So each NS iteration requires 3 allgathers, one per matmul.
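The placement bookkeeping above can be checked without DTensor. The following is a single-process NumPy simulation, not a DTensor implementation: list entries play the role of per-rank `Shard(0)` blocks, `np.concatenate` stands in for the allgather, and the coefficients are the assumed Muon defaults. It performs one gather per matmul, mirroring the count above.

```python
import numpy as np

def ns_reference(X, steps, a, b, c):
    """Unsharded NS loop, used only as ground truth."""
    for _ in range(steps):
        A = X @ X.T
        B = b * A + c * A @ A
        X = a * X + B @ X
    return X

def ns_row_sharded(x_shards, steps, a, b, c):
    """Row-sharded NS loop; returns the output shards and the number of gathers used."""
    n_gathers = 0

    def all_gather(shards):
        nonlocal n_gathers
        n_gathers += 1
        return np.concatenate(shards, axis=0)

    for _ in range(steps):
        X_T = all_gather(x_shards).T           # gather 1: right operand of X @ X.T
        a_shards = [x @ X_T for x in x_shards]  # each rank owns a row block of A
        A_full = all_gather(a_shards)           # gather 2: right operand of A @ A
        b_shards = [b * s + c * s @ A_full for s in a_shards]
        X_full = all_gather(x_shards)           # gather 3: right operand of B @ X
        x_shards = [a * x + bs @ X_full for x, bs in zip(x_shards, b_shards)]
    return x_shards, n_gathers

rng = np.random.default_rng(0)
ws, R, C, N = 4, 8, 16, 5
a, b, c = 3.4445, -4.7750, 2.0315  # assumed coefficients
X0 = rng.standard_normal((R, C))
X0 /= np.linalg.norm(X0)

sharded, n_gathers = ns_row_sharded(np.split(X0, ws, axis=0), N, a, b, c)
reference = ns_reference(X0, N, a, b, c)
```

The sharded result matches the unsharded loop, and the gather counter comes out to `3 * N`.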
The external loop then becomes: $$ m_t = \mu m_{t-1} + g_i\newline x_t = \mu m_t + g_i \newline o_t = \text{NS-TP}(x_t)\newline w_t = w_{t-1} + \eta_to_t $$
In summary,
- all matmul/scalar compute in the entire optimizer is perfectly sharded,
- Instead of a single allgather per parameter, \(3N\) all-gathers are required as extra comms.
Intuitively, we may expect that the required comms for this approach is unacceptable. Noting that a basic full FSDP trainer will require 3 RS/AG collectives (forward all-gather, backward all-gather + reduce-scatter) per parameter, the TP approach (if done at full world size) effectively bloats the required comms per training step by \(N\) times.
However, the above reasoning does not take into account the possibility of overlapped communications. In theory, it should be possible to unroll the required \(\text{NS-TP}\) steps across all parameters into a series of `allgatherAsync`, `matmul`, and `mul` ops, where
- `allgatherAsync` is executed in a separate communication stream,
- each `matmul` is scheduled by events to only occur after a given `allgatherAsync`. There is a bijection from each unique matmul to a unique allgather op.
Unfortunately, it is highly unlikely that the above overlap will be favorable to compute time. Assume \(X_t\) is shaped \([R,C]\), where \(R \le C\). Then,
- The required comms-per-rank for `Shard(0)->Replicate()` is generally \(RC\) or \(R^2\).
- The required flops-per-rank for a matmul is \(2R^2C\) or \(2R^3\) in all cases. The former happens twice, and the latter once.
Using jax-ml’s roofline notation, the arithmetic intensity of a given allgather+matmul pair is always \(2R\). So:
Arithmetic intensity | Accelerator Intensity (intra) | Accel Int (inter) |
---|---|---|
2R | ~2000 | >28000 |
1 | ~1000/R | >14000/R |
For \(\text{NS-TP}\) to not be comm-bottlenecked in the ideal overlapped async impl, the average smallest dim \(R\) (assuming pow2) of each 2d parameter must be at least `1024` (intranode) or `16384` (internode).
This approach is thus obviously not optimal beyond a single node.
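The thresholds follow from requiring \(2R\) to be at least the accelerator intensity (FLOP/s divided by elements/s) and rounding up to a power of two. A quick check using the "Real" FLOP/s figure with the intranode and internode bandwidths from the table earlier (the function name is made up for illustration):

```python
import math

def min_pow2_dim(flops_per_s: float, elems_per_s: float) -> int:
    """Smallest power-of-two R such that 2R >= accelerator intensity."""
    accel_intensity = flops_per_s / elems_per_s
    r_min = accel_intensity / 2
    return 2 ** math.ceil(math.log2(r_min))

intra = min_pow2_dim(0.7e15, 0.37e12)   # accelerator intensity ~1900
inter = min_pow2_dim(0.7e15, 0.025e12)  # accelerator intensity 28000
```

This gives 1024 intranode and 16384 internode, matching the numbers quoted above.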
Rank-conditional Newton-Schulz#
The simple choice to shard compute better is to only compute \(O_t\) on a single accelerator, and broadcast the results to the rest: $$ m_t = \mu m_{t-1} + g_i\newline x_t = \mu m_t+g_i\newline r = \text{selectRank}(m_t\text{.shape})\newline X_t = \text{gatherAsync}(x_t, r)\newline O_t = \begin{cases}\text{NS}(X_t) & \text{if }r=i\newline \text{None} &\text{if }r\neq i\end{cases}\newline O_t = \text{broadcastAsync}(O_t, r)\newline w_t = w_{t-1} + \eta_tO_t[i] $$ Here, I assume you have some means of ensuring gather/broadcast ops are called at appropriate moments on the communication stream without blocking compute. If there was a magical event loop to expose an async interface to cuda ops, it’d look like this:
```python
async def muon_naive(m_prev: DTensor, w_last: DTensor, g_shard: DTensor, mu: float, lr: float, r: int):
    m_t = mu * m_prev + g_shard
    x_t = mu * m_t + g_shard
    X_t = gatherAsync(x_t, r)
    i = dist.get_rank()
    O_t = NS(await X_t) if r == i else torch.empty(x_t.shape)
    O_t = await broadcastAsync(O_t, r)
    w_t = w_last + lr * distribute_tensor(O_t, ...)
```
In practice, you can achieve something similar by looping through all parameters and enqueuing async work with appropriate waits:
```python
from typing import Callable

import torch
import torch.cuda as cu
import torch.distributed as dist
from torch import Tensor as TT

def apply_async(f: Callable[[TT], TT], t: TT, s_comm: cu.Stream, r: int) -> TT:
    i, ws = dist.get_rank(), dist.get_world_size()
    s_comp: cu.Stream = cu.default_stream()
    # t was produced on the compute stream, so the comm stream must wait for it first.
    s_comm.wait_stream(s_comp)
    with cu.stream(s_comm):
        # gather_list may only be passed on the destination rank.
        out = [torch.empty_like(t) for _ in range(ws)] if i == r else None
        dist.gather(t, out, dst=r)
        if i == r: out = torch.cat(out)
    if i == r:
        # only i == r will have forced synchronization of both streams to each other.
        s_comp.wait_stream(s_comm)
        out = f(out)
        s_comm.wait_stream(s_comp)
    else: out = torch.empty(ws, *t.shape, device=t.device, dtype=t.dtype).flatten(0, 1)
    with cu.stream(s_comm):
        dist.broadcast(out, r)
    # output is only safe after comm is finished.
    s_comp.wait_stream(s_comm)
    return out.unflatten(0, (ws, -1))  # <-- safe for usage on default stream

mu, lr = ...
for w, g, m, r in optim.param_iterator():  # <-- selectRank process to be determined
    m.mul_(mu).add_(g)
    g.add_(m, alpha=mu)
    # create a unique comm stream per parameter.
    o = apply_async(NS, g, torch.cuda.Stream(), r)
    w.add_(o[dist.get_rank()], alpha=-lr)  # each rank applies only its own slice, O_t[i]
```
The above code should cover the rough idea. Note that it is not stable as written: it will blow up memory because it tries to async-gather all params at once (add a prefetch cap to resolve that), and it will also have some unnecessary compute bubbles.
If you implement the above in a more sensible way, you should obtain:
- perfect memory sharding,
- decent compute sharding (not 100% perfect because params will be of different size && it is impossible to perfectly distribute compute by param filtering)
- Slightly more than 2x comms relative to the NS Replication approach. broadcast/gather both use similar ring algos to reducescatter/allgather && require similar bus volumes per collective.
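The post leaves `selectRank` undetermined. One simple possibility, offered purely as an assumption, is a greedy longest-processing-time assignment: sort parameters by estimated NS cost and always hand the next one to the least-loaded rank. The cost estimate uses the per-iteration FLOP counts from the TP section (\(2R^2C\) twice, \(2R^3\) once); `ns_flops` and `select_ranks` are hypothetical names.

```python
import heapq

def ns_flops(shape, steps=5):
    """Estimated NS FLOPs for a 2D param: two 2*R^2*C matmuls and one 2*R^3 matmul per step."""
    R, C = min(shape), max(shape)
    return steps * (2 * (2 * R * R * C) + 2 * R ** 3)

def select_ranks(shapes, world_size, steps=5):
    """Greedy assignment: largest params first, each to the currently least-loaded rank."""
    heap = [(0, r) for r in range(world_size)]  # (accumulated flops, rank)
    order = sorted(range(len(shapes)), key=lambda k: -ns_flops(shapes[k], steps))
    owners = [None] * len(shapes)
    for k in order:
        load, r = heapq.heappop(heap)
        owners[k] = r
        heapq.heappush(heap, (load + ns_flops(shapes[k], steps), r))
    return owners

# Illustrative shapes: four square attention matrices and three wider MLP matrices.
shapes = [(4096, 4096)] * 4 + [(4096, 11008)] * 3
owners = select_ranks(shapes, world_size=4)
loads = [sum(ns_flops(s) for s, o in zip(shapes, owners) if o == r) for r in range(4)]
```

Even on this tiny example the per-rank loads end up uneven, which illustrates why compute sharding by param filtering cannot be perfect.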
Seems pretty good. Is that the end?
Fused-FSDP Overlap#
There is one extra trick we can use to potentially reduce the required comm overhead of various approaches to zero, under certain arithmetic intensity regimes.
Let’s think back to our basic NS Replicated algorithm:
$$
m_t = \mu m_{t-1} + g_i\newline
x_t = \mu m_t+g_i\newline
X_t = \text{allgather}(x_t)\newline
O_t = \text{NS}(X_t)\newline
w_t = w_{t-1} + \eta_tO_t[i]
$$
Assume that hell freezes over, and torch begins to expose a simple API to adjust the behavior of the FSDP pre-fwd/post-bwd hooks per parameter (rather than per `FSDPModule`, which you would probably have to do in current torch).
Consider: In the FSDP post-backward hook, we can add code to compute \(x_t\), storing it under the param’s sharded grad:
```python
def fsdp_bwd(p: FSDPParam):
    p.shard() # or whatever method to reduce-scatter p's grads
    m_t = optim.get_momentum_from_fqn(p.fqn)
    m_t.mul_(optim.mu).add_(x_t := p.sharded_param.grad)
    x_t.add_(m_t, alpha=optim.mu)
```
Then, in the pre-forward hook, we can add code to allgather \(X_t\), calculate \(O_t\), and update \(w_t\) before the param gather:
```python
def fsdp_fwd(p: FSDPParam):
    # NOTE: possibly you will want to do all of the below on the FSDP comm stream
    X_t = funcol.allgather(p.sharded_param.grad)
    p.zero_grad() # or however to delete all grads from p
    o_t = NS(X_t)[dist.get_rank()]
    p.sharded_param.data.add_(o_t, alpha=-optim.lr)
    p.unshard() # unshard params for actual fwd
```
The effective change here is that,
- the bwd becomes slightly more compute heavy. This could improve comm-comp overlap.
- the fwd has an extra allgather && param update.
- If FSDP was already comms-bottlenecked, this is no different from the basic NS replicated case.
- If FSDP was not comms-bottlenecked, and in fact has sufficient arithmetic intensity to support 2x comms, then the cost of the allgather is hidden and the optimizer overhead becomes purely compute driven.
Of course, this is just a sketch, and would be decently hard to implement with current FSDP2 without deep knowledge of internals :(
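The two regimes in the list above can be illustrated with a toy timing model. It assumes perfect comm/compute overlap, so the forward time of one FSDP unit is just the larger of total compute and total comm, and it assumes the extra \(X_t\) allgather costs the same as the param allgather. All names and numbers are made up for illustration.

```python
def fwd_time(t_compute: float, t_allgather: float, t_ns: float, fused_muon: bool) -> float:
    """Forward time for one FSDP unit under ideal overlap: max(total compute, total comm)."""
    if not fused_muon:
        return max(t_compute, t_allgather)
    # Fused: one extra allgather (for X_t) plus the NS compute in the pre-forward hook.
    return max(t_compute + t_ns, 2 * t_allgather)

# Compute-bound unit: the extra allgather hides, overhead is just the NS compute.
compute_bound_overhead = fwd_time(10.0, 4.0, 1.0, True) - fwd_time(10.0, 4.0, 1.0, False)

# Comm-bound unit: the extra allgather is fully exposed.
comm_bound_overhead = fwd_time(3.0, 4.0, 1.0, True) - fwd_time(3.0, 4.0, 1.0, False)
```

In the compute-bound case the overhead equals `t_ns`, and in the comm-bound case it equals one full allgather, which is the same cost as the unfused NS Replication approach.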
Rank-conditional FSDP-Fused#
The above argument can also work for the rank-conditional approach. It just requires some amount of care, because
- the code must allow for multiple `FSDPParam`s to be processed in async for rank-conditional execution to be useful,
- the code must not enqueue more than a single FSDPModule’s worth of parameters to do async.
A standard llama transformer block has only 9 unique parameters. This greatly limits the potential of rank-conditional execution + FSDP.
TP#
In the case of 2D FSDP (internode) x TP (intranode), the case is fairly compelling. AsyncTP should be efficient enough to parallelize NS intranode, and the outer FSDP hooks don’t change in appearance.
Conclusion#
there are probably more variants but someone else can discuss them