This is a companion longpost for a fun project I’ve yet to finish. In here, I show the reader how I personally visualize the collective communications involved in a simple 2⁶ 6D parallel mesh:
There are many articles that describe various training parallelisms in vague words and simple visuals. Most of them fail to convey a deep understanding of the exact communications involved in a single training step, and even the outliers that do rarely cover the more complex case of combining all the approaches at once.
So, in this post, I will attempt to bring you, the reader, through what changes in the execution of forward/backward passes as various parallelisms are applied to an extremely simple model architecture.
(0D) Single GPU case
In every researcher’s fantasies, there exists the Perfect GPU. It has infinite memory, infinite FLOPs, and never communicates with any other device.
In that world, the forward-backward loop of a training run looks like this:
Which is quite boring. In that world, nobody would try to deny scaling laws in a self-deceiving attempt to cover up their own engineering inadequacies.
Let’s add something more.
(1D) - Data Parallel
Everyone understands how Distributed Data Parallelism works at a superficial level.
- Identical copies of the model exist on every accelerator.
- Each accelerator receives a different batch of data to fwd-bwd on.
- Before an optimizer step, the gradients are synced across accelerators via an AllReduce.
But what of the details? Quoting PyTorch:
Despite the conceptual simplicity of the technique, the subtle dependencies between computation and communication make it non-trivial to optimize the distributed training efficiency.
As of v1.5, PyTorch natively provides several techniques to accelerate distributed data parallel, including bucketing gradients, overlapping computation with communication, and skipping gradient synchronization.
A tip for life: it is possible to make anything arbitrarily complex. So, I make the following simplifications, which I claim are fair for educational purposes:
- I treat the implementation of the AllReduce collective op as a singular black box function call. In particular, I pay no attention to exactly which messages of what size are passed when under NCCL’s adaptive algorithm. I suggest reading this page if you seek a more concrete understanding.
- I assume bucketing is done on a per-layer basis (like FSDP2), rather than by size thresholds (which DDP and deepspeed both implement). In practice, the size thresholds could reasonably be set to match the size of a layer.
- I assume no use of gradient accumulation, which implies synchronization is never skipped.
With the above in mind, the visualization is updated to the following:
Conceptually, a coalesced allreduce is applied over the gradients of each layer’s parameters after its backward \(B_i\) is executed. The computation of \(B_{i-1}\) overlaps with the synchronization of layer \(i\)’s gradients. The last synchronization has no (model) compute to overlap with, and executes as the optimizer step occurs.
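To make the overlap concrete, here is a minimal sketch of that pattern with per-parameter backward hooks. This is not how DDP's actual (C++, bucketed) Reducer works; it only shows the shape of "launch an async allreduce as each grad becomes ready, wait on the stragglers before stepping". The helper name is mine:

```python
import torch
import torch.distributed as dist

def attach_grad_allreduce(model: torch.nn.Module, world_size: int) -> list:
    """Launch an async allreduce as each parameter's grad is accumulated,
    so gradient sync overlaps with the rest of the backward pass."""
    handles = []

    def make_hook(param: torch.Tensor):
        def hook(*_):
            param.grad.div_(world_size)  # pre-divide so the SUM allreduce averages
            handles.append(dist.all_reduce(param.grad, async_op=True))
        return hook

    for p in model.parameters():
        p.register_post_accumulate_grad_hook(make_hook(p))
    return handles

# per step (sketch):
#   loss.backward()              # hooks fire layer by layer, comms overlap compute
#   for h in handles: h.wait()   # the last layer's sync has nothing to overlap with
#   optimizer.step(); handles.clear()
```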
As this is 1D parallelism, there is only one group of devices, communicating across a single 'line'. Casual PyTorch users should be familiar with this setup, as it matches the behavior of a simple training run with only a single global distributed process group.
(2D) - Hybrid/Fully Sharded Data Parallel
Outside of special circumstances, most people with less than 1000 GPUs will only ever need Fully Sharded Data Parallel (FSDP) to train models. As Meta reports, basic Llama 7B training retains reasonable MFU up to around \(8*2^6=512\) GPUs. Most people are more familiar with FSDP than anything else.
However, the FSDP/deepspeed APIs are designed with magical, “it just works” experiences in mind.
As a result, many users of FSDP do not understand the broad strokes of how its sharding is implemented, much less in detail. Frontier LLMs seem to not know much, either (try it!). I myself chose to ignore the details for the longest time, often feeling exasperated as the black box behaved in ways I did not understand.
To give a simple motivating example: in pure DP trainers, it is common to see rank-conditional code used for evaluation/logging purposes, e.g.
```python
if dist.get_rank() == 0:
    with torch.no_grad():
        sample = model.generate(...)
```
FSDP users will recognize this as a code smell – something liable to produce NCCL deadlocks when `model` is FSDP-wrapped.
This is sensible: the ZeRO3 algorithm shards parameters across accelerators, and those parameters are required for generating with the forward pass. If you wanted rank-conditional evaluation code to run correctly, you would have to disable parameter sharding, and apply ZeRO2/1 instead.
Except, that doesn’t always work. Different implementations of ZeRO2 may or may not require comms to occur in subsequent forward passes:
- DistributedFusedAdam launches parameter syncs within the optim step (which are blocking-by-default), and doesn’t require any distributed comms in a forward pass.
- torch’s implementation of ZeRO2 (i.e. `SHARD_GRAD_OP` / `reshard_after_forward=False`) puts parameter synchronization in the forward pass, blocking rank-conditional execution. Deepspeed doesn’t.
For communication efficiency (allgathering params is usually cheaper than moving optimizer state around, since sizeof(params) < sizeof(optim) for Adam-like optimizers), most implementations of ZeRO2 actually use local optim shards to update local parameter shards, before allgathering params to sync. So, ZeRO2, despite the advertising, actually still has sharded parameters that have to get gathered.
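As a rough sketch of that update-then-allgather pattern (a hypothetical helper, not any particular library's API; it assumes gradients were already reduce-scattered and the optimizer was constructed over this rank's param shard only):

```python
import torch
import torch.distributed as dist

def zero2_optimizer_step(flat_param: torch.Tensor, local_param_shard: torch.Tensor,
                         local_grad_shard: torch.Tensor, opt: torch.optim.Optimizer):
    # the optimizer state only covers this rank's 1/world_size slice of the params
    local_param_shard.grad = local_grad_shard  # produced by an earlier reduce-scatter
    opt.step()                                 # updates only the local param shard
    opt.zero_grad(set_to_none=True)

    # gathering params (1x model size) beats moving Adam state (2x model size)
    dist.all_gather_into_tensor(flat_param, local_param_shard)
```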
To avoid getting bogged down in the many different ways in which FSDP could be implemented, my visualization only demonstrates a single, specific case: 2D Hybrid Sharded Data Parallelism, in which
- Pytorch’s approach to ZeRO2 is applied over one axis of devices,
- DDP, as described in the previous section, is applied over another axis.
In Pytorch terms, each GPU has access to:
- the shape of the full `2x2` Device Mesh,
- process groups (lines) for each parallelism, through which collective communications occur (a setup sketch follows the list).
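For concreteness, here is a minimal sketch of how such a 2x2 mesh and its per-axis process groups could be set up with torch's DeviceMesh and the FSDP2-style `fully_shard` API (import paths vary across torch versions; the stack of linear layers is just a stand-in model, not the architecture used later):

```python
import torch
from torch import nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import fully_shard  # path differs in older torch versions

# assumes torchrun with 4 processes, 1 GPU each
mesh = init_device_mesh("cuda", (2, 2), mesh_dim_names=("dp", "fs"))

model = nn.Sequential(*[nn.Linear(1024, 1024, bias=False) for _ in range(4)]).cuda()
for layer in model:
    # HSDP: shard each layer's params along "fs", replicate them along "dp"
    fully_shard(layer, mesh=mesh, reshard_after_forward=False)  # ~ZeRO2 behavior
fully_shard(model, mesh=mesh, reshard_after_forward=False)

fs_group = mesh["fs"].get_group()  # the sharding 'line'
dp_group = mesh["dp"].get_group()  # the replication 'line'
```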
Before each layer \(i\)’s forward \(F_i\), its parameters must be allgathered, as the compute for \(F_i\) requires the full parameters for layer \(i\) to exist on-device. After all forward steps are done, each accelerator has a full copy of the parameters, which get used immediately in the subsequent backward steps. Each completed backward \(B_i\) triggers (sketched after the list):
- a reduce-scatter of the local, full gradients into gradient shards synchronized along the `FS` axis,
- an additional all-reduce of each gradient shard to sync along the `DP` axis,
- (a deletion of the full params for layer \(i\), keeping only a param shard, to save memory).
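In (purely conceptual) collective calls, that per-layer schedule looks roughly like the following. Real FSDP does this on separate CUDA streams with prefetching; the sketch, with hypothetical helper names, only pins down which collectives run on which axis:

```python
import torch
import torch.distributed as dist

def before_forward_Fi(param_shard: torch.Tensor, full_param: torch.Tensor, fs_group):
    # F_i needs the full parameters of layer i on-device: allgather the shards
    dist.all_gather_into_tensor(full_param, param_shard, group=fs_group)

def after_backward_Bi(full_grad: torch.Tensor, grad_shard: torch.Tensor,
                      fs_group, dp_group):
    # 1) reduce-scatter the full local gradient into this rank's FS shard
    dist.reduce_scatter_tensor(grad_shard, full_grad, group=fs_group)
    # 2) all-reduce that shard across the DP (replica) axis
    dist.all_reduce(grad_shard, group=dp_group)
    # average over all data-parallel contributors (FS * DP ranks)
    grad_shard.div_(dist.get_world_size(fs_group) * dist.get_world_size(dp_group))
    # 3) the full grads (and the unsharded params) for layer i can now be freed
```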
Try to remember how this works, because it becomes important in Pipeline Parallelism later.
Interruption - Layer architecture
DDP/FSDP are mostly architecture agnostic, and the past 3 visualizations apply sensibly for any model that involves a simple sequential computational graph.
Other parallelisms are not so abstract. The communications involved in tensor/context/expert parallelism change greatly depending on the architecture involved.
So, for future visualizations here, I assume the use of a pure Attention layer (or MoMHA in the EP case) repeated L times, i.e.
```python
import torch.nn.functional as F
from torch import nn, Tensor

Linear = lambda a, b: nn.Linear(a, b, bias=False)

class Att(nn.Module):
    def __init__(self, d: int, h: int, *, d_h: int = 128):
        super().__init__()
        self.wk = Linear(d, h * d_h)
        self.wv = Linear(d, h * d_h)
        self.wq = Linear(d, h * d_h)
        self.wo = Linear(h * d_h, d)
        self.d_h = d_h

    def forward(self, x: Tensor) -> Tensor:
        q, k, v = self.wq(x), self.wk(x), self.wv(x)
        # (B, S, h*d_h) -> (B, h, S, d_h) so SDPA attends over the sequence dim
        q, k, v = (t.unflatten(-1, (-1, self.d_h)).transpose(1, 2) for t in [q, k, v])
        o = F.scaled_dot_product_attention(q, k, v)
        return self.wo(o.transpose(1, 2).flatten(-2))

def create_model(L: int, D: int = 1024):
    return nn.Sequential(*[Att(D, D // 128) for _ in range(L)])
```
This is only done to simplify the visualization – in practice, nobody should create models stacked purely with attention layers, and especially with no norms. I exclude such things to keep the visualization as clean as possible, but an excited reader can choose to imagine, e.g. a separate MLP layer in which CP comms do not occur, or norm layers in which additional TP allgathers happen.
Moving on,
(3D) - Tensor Parallel
The model architecture described above fits the pattern of combined column+rowwise parallelism (colwise for `wq`/`wk`/`wv`, rowwise for `wo`). Consequently, the only add-on to the visualization is a simple blocking allreduce call at the end of every layer:
In this visualization, each square on the 2x2 grid represents a GPU Pair, making the full grid itself a single 8x GPU node. The inner dot within each square lights up whenever a TP communication occurs.
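For readers who prefer code to diagrams, here is a rough sketch of that column+rowwise plan using torch's tensor-parallel API (treating the TP pair as its own 1D mesh for brevity; this is illustrative, not necessarily how the visualization was produced):

```python
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel, RowwiseParallel, parallelize_module,
)

tp_mesh = init_device_mesh("cuda", (2,), mesh_dim_names=("tp",))

model = create_model(L=4).cuda()
for layer in model:
    parallelize_module(layer, tp_mesh, {
        "wq": ColwiseParallel(),   # shard q/k/v output features (i.e. heads)
        "wk": ColwiseParallel(),
        "wv": ColwiseParallel(),
        "wo": RowwiseParallel(),   # shard wo's input features; this emits the allreduce
    })
```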
Why would you ever do this? Consider the simple case of an 8x3090 node, where:
- Pairs of GPUs are NVLink’d, making TP viable,
- Groups of 4x GPUs are separated across a NUMA Boundary.
Most engineers should find this disturbing and misleading. After all, real models will involve a reduceScatter -> norm -> allGather pattern, rather than a single blocking allReduce. Furthermore, that pattern may even be further optimized to a series of interleaved P2P reads. Also, no effort has been made to understand the tail circumstances in which an unpaired rowwise/colwise parallel layer may be needed.
To this, I say:
- it’s accurate to the architecture given in the previous section.
- the visualization would not change much if the allReduce was replaced by a reduceScatter+pause+allGather. The duration would be almost the same.
- I think readers would understand less if the TP component involved many tiny repeated slices of comms that lit up for 0.01s each.
Still, I fully understand if a skilled reader feels disgusted by the simplified presentation. I can only apologize for my limited capabilities.
(4D) - Context Parallel
As with the last section, my visualization of Context Parallelism only makes sense in the context (really?) of an Attention layer. When an attention layer receives inputs sharded on the sequence dimension, it must transport individual KV shards (and their grads) between different accelerators.
In general, you can extrapolate the concept of “Context Parallelism” to any architecture that operates over the sequence dimension. For instance, a SoftMoE layer might require similar comms to obtain its dispatch weights / input slots when context parallelized.
For architectures that treat the sequence dim as an extra batch dimension, Context Parallelism is meaningless. So, if I had used a simple FFN layer as the target architecture instead, the diagram below would simply have no CP comms in it:
The `Full Mesh` in the embed shows the relative hierarchies of the parallelisms involved. Because context/tensor parallelism are more liable to block compute kernels, they’re placed at the bottom of the hierarchy, communicating with devices that are closer. Each tile will also display its device index if you hover over it – it makes more sense to keep FS comms lower latency.
In this visualization, I follow the Llama3/Pytorch practice of implementing context parallelism via a simple Allgather, rather than adopting the irecv/isend approach used in ring attention variants.
This is a key decision that directly impacts the visualization used. While the allgather (or all2all) approach I show blocks the compute stream, a ring-attention like approach would instead have its communications overlapped async, with repeated interleaved P2P ops.
I picked the former as it is the 'trap' a PyTorch user will fall into by default, if they apply the `context_parallel` API as is done in torchtitan.
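As an illustration of the allgather flavor (forward only, ignoring causal-masking/load-balanced-sharding details and the mirrored comms in backward), a hypothetical attention step on one CP rank might look like:

```python
import torch
import torch.distributed as dist
import torch.nn.functional as F

def cp_attention(q, k, v, cp_group):
    """q/k/v are (B, H, S_local, d_h) shards of the sequence on this rank."""
    cp = dist.get_world_size(cp_group)
    k_list = [torch.empty_like(k) for _ in range(cp)]
    v_list = [torch.empty_like(v) for _ in range(cp)]
    dist.all_gather(k_list, k, group=cp_group)  # blocking: stalls the compute stream
    dist.all_gather(v_list, v, group=cp_group)
    k_full = torch.cat(k_list, dim=-2)          # reassemble the full sequence of K
    v_full = torch.cat(v_list, dim=-2)          # ...and V; queries stay local
    return F.scaled_dot_product_attention(q, k_full, v_full)
```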
(5D) - Expert Parallel
Expert parallelism is an edge case of an edge case.
Certain layers in certain models will be MoE layers. Regardless of their routing strategy, almost all MoE architectures:
- have routed experts of the same parameter count / compute requirements, which can be split across devices,
- do not maintain the same routing order across layers, thus necessitating an immediate gather/scatter of inputs across devices,
- place their routers right before expert compute, making non-blocking comms implausible (though an async-TP-like overlap is plausible)
In those cases, the following visualization makes sense:
Per layer,
- a blocking all2all is used to send expert-bucketed inputs to the accelerator(s) on which the experts exist,
- another all2all is used to pass expert outputs back to their original data parallel processes (the dispatch half is sketched below).
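To pin down what the dispatch half involves, here is a hypothetical sketch built on `all_to_all_single` (the token pre-permutation, the `counts` bookkeeping, and the helper name are my assumptions, not a real library API; the return path is a mirrored call with the split sizes swapped):

```python
import torch
import torch.distributed as dist

def dispatch_to_experts(x_by_dest: torch.Tensor, counts: torch.Tensor, ep_group):
    """x_by_dest: tokens already permuted so those bound for rank r's experts are
    contiguous; counts[r] = number of tokens this rank sends to rank r."""
    # exchange counts first so every rank knows how many tokens it will receive
    recv_counts = torch.empty_like(counts)
    dist.all_to_all_single(recv_counts, counts, group=ep_group)

    recv = x_by_dest.new_empty(int(recv_counts.sum()), x_by_dest.shape[-1])
    dist.all_to_all_single(                      # the blocking all2all itself
        recv, x_by_dest,
        output_split_sizes=recv_counts.tolist(),
        input_split_sizes=counts.tolist(),
        group=ep_group,
    )
    return recv
```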
This is basically the same as Deepspeed-TED, except that I put the 2nd all2all before TP’s allreduce, rather than after, because it should save comm costs if top-k > 1 MoE is used.
I also ignore the reality that the displayed approach above is inefficient for an MoMHA layer. Typical FFN MoE layers involve blocking all2all calls, as the router can only be called directly before the experts. In the case of MoMHA, it should be possible to overlap comms as so:
```python
def forward(self, x: Tensor) -> Tensor:
    logits = self.router(x)
    x_routed = self.permute_and_all2all(x, logits)  # <- some async op
    k, v = self.wk(x), self.wv(x)
    q = self.q_experts(x_routed)
    # ... (rest of attention)
```
But it would probably be more confusing to show that? ¯\_(ツ)_/¯
(6D) - Pipeline Parallel
“Was it wise to pipeline? As we now know, pipelining is not wise. But we were not as wise back then…”
Everyone hates pipeline parallelism. It infects/destroys otherwise clean training code, creates gigantic bubbles of compute inactivity, and worst of all, is conceptually simple enough to make an engineer feel it should be easy to implement.
The SOTA in pipelining is ZBPP, which promises something close to perfect pipelining, at the cost of your entire codebase + my entire visualization setup thus far. Although PyTorch allegedly supports ZeroBubble, the last time I tried it I received a segmentation fault in libtorch only when torch.compile was enabled.
Therefore, instead of attempting to render that, I merely showcase the perspective of a single PP group with 2 microbatches under a simple GPipe schedule. The theoretical shape of that pipeline looks like:
```
0 : fF___bbBB
1 : _fFbbBB__
```
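Reading the diagram: lowercase/uppercase letters are microbatches 0/1, and a backward slot is drawn twice as wide as a forward. In very rough code, rank 1's side of this schedule amounts to something like the sketch below (blocking p2p ops, a dummy loss, and none of torch's `pipelining` machinery):

```python
import torch
import torch.distributed as dist

def rank1_gpipe_step(stage: torch.nn.Module, shapes, prev_rank: int = 0):
    """Rank 1 (the last stage): recv activations for both microbatches, run the
    forwards, then run the backwards and send activation grads back to rank 0."""
    acts, outs = [], []
    for shape in shapes:                       # the 'fF' phase
        x = torch.empty(shape, device="cuda")
        dist.recv(x, src=prev_rank)            # p2p: activations from rank 0
        x.requires_grad_(True)
        acts.append(x)
        outs.append(stage(x))
    for x, y in zip(acts, outs):               # the 'bbBB' phase
        y.sum().backward()                     # stand-in for a real loss
        dist.send(x.grad, dst=prev_rank)       # p2p: grads feed rank 0's 'bb'/'BB'
```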
So, for Rank 1, the visualization has appropriate bubbles pausing execution at the beginning and end. Appropriate p2p comms are added where Rank 1 should be receiving/sending activations from/to Rank 0:
As you might have noticed at the beginning, the pipeline stage tabs in the iframe above are fake cosmetics. This is a skill issue and I encourage more capable people to publish better visualizations if at all possible.
But there are some less noticeable choices here I want to highlight.
Mesh order
If you hover over the device indices of the full mesh in detail, you may notice that the parallelism hierarchy rises in the order of `[TP, EP, CP, FS, DP, PP]` – GPUs are closest along TP lines, and furthest along PP lines.
That hierarchy isn’t important for a mere 2⁶ GPUs, but it could make sense at scale (a device-mesh sketch follows the list):
- \(TP=8\) applied within-node, with low-latency async TP over NVSwitch.
- \(EP=16\) across rail-optimized leaf switches for best all2all perf
- \(CP\times FS\times DP=256\), such that the 5D submesh `[TP,EP,CP,FS,DP]` fills a 32k island
- \(CP > FS > DP\) in terms of latency priority
- \(PP=?\) across islands
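As a sketch, the toy 2⁶ mesh from this post (and the hypothetical at-scale variant above) could be declared with torch's DeviceMesh, ordered from furthest (PP) to closest (TP):

```python
from torch.distributed.device_mesh import init_device_mesh

# the 2^6 = 64-GPU toy mesh used in this post, outermost/furthest to innermost/closest
mesh_6d = init_device_mesh(
    "cuda", (2, 2, 2, 2, 2, 2),
    mesh_dim_names=("pp", "dp", "fs", "cp", "ep", "tp"),
)

tp_group = mesh_6d["tp"].get_group()  # neighbors: ranks differ only in the last mesh dim
pp_group = mesh_6d["pp"].get_group()  # furthest: ranks differ in the first mesh dim

# the hypothetical at-scale version from the list above would look the same,
# just with shape (n_islands, dp, fs, cp, 16, 8)
```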
It could also be wrong: the llama3 paper argues that FSDP should rest at the highest mesh dim. But I’m skeptical after having written the following section:
Pipelining and FSDP
Pipeline parallelism necessitates gradient accumulation by design. A single step has to be split into multiple forward/backward steps to prevent pipelining from being laughably inefficient.
Most parallelisms do not hate each other. Mesh parallelism exists as a concept simply because it is easy to compose FSDP, TP, CP, EP, etc. in mostly harmonious ways.
But pipelining and FSDP are at odds. If you allowed a naive implementation of FSDP to be used, each forward and backward microbatch in a pipeline schedule would require its own allgather/reducescatter. This quickly multiplies the communication cost of FSDP by \(O(\text{microbatches})\), which will obviously destroy MFU if, e.g., \(\text{microbatches} \geq 24\) as in ZBPP.
Let’s call the naive case above the Overcommunication case.
To avoid Overcommunication, it is possible to skip gradient synchronization and parameter resharding for most microbatches, saving on communication costs. Structurally, this is similar to how gradient accumulation is implemented with DDP.
But, when implemented naively, it is possible to accidentally nullify all memory savings from FSDP, making it equivalent to pure data parallelism with worse latency. I call this the Memory Bloated case:
- Currently, Pytorch’s `pipelining` submodule disables `reshard_after_backward` and `requires_gradient_sync` in FSDP backward until the last microbatch. Since PP is useless for `n_microbatches <= 1`, that disabling will always happen for any train step.
- That behavior blocks gradient synchronization and parameter resharding, causing the full copy of all params and grads of all layers to be present on each device (+ extra waste for shard buffers that aren’t deleted).
- If used with SGD, this should have worse peak memory than pure PP + DP. Presumably, this bug remains uncovered due to the memory savings from sharding Adam, which will still exist.
In my visualization, I avoid modelling that behavior, and instead consider what I believe is the GPipe Optimal case:
- in GPipe, because all forward steps are executed at the start, and I’m using ZeRO2, only 1x allgather of params is required per layer.
- for every backward layer step, I create new local gradients and reduce-scatter them to accumulated gradient shards. This ensures that the total memory required for storing gradients is always roughly equivalent to that of their sharded size.
- only for the last microbatch’s backward, I apply an allreduce of gradients across the DP axis. Meaning: `no_sync` is applied along DP, but not FS (a per-layer sketch follows the list).
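Conceptually, per layer and per microbatch, that GPipe Optimal backward looks like this hypothetical helper (not torch's actual hook machinery):

```python
import torch
import torch.distributed as dist

def gpipe_optimal_backward(full_grad: torch.Tensor, grad_shard_accum: torch.Tensor,
                           fs_group, dp_group, is_last_microbatch: bool):
    # every microbatch: fold fresh local grads into the accumulated FS shard,
    # so gradient memory stays roughly at sharded size
    shard = torch.empty_like(grad_shard_accum)
    dist.reduce_scatter_tensor(shard, full_grad, group=fs_group)
    grad_shard_accum += shard
    del full_grad, shard  # stands in for freeing the unsharded gradient buffer

    # only the last microbatch syncs across the DP axis (no_sync otherwise)
    if is_last_microbatch:
        dist.all_reduce(grad_shard_accum, group=dp_group)
```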
All 3 cases are available in the modified visualization below:
Sidenote: I am unaware of any academic references for these claims. This is simply a summarization of what I learned from reading Pytorch code/issues. I recommend doing the same if you seek a full understanding of this blogpost.
Ending notes
If you’ve made it this far (without skipping to the end), you should now:
- have some vague awareness of what collective comms each flavor of parallelism uses,
- have a working copy of my mental model of how combined 6D mesh parallelism should work,
- be aware of extraneous trivia regarding various OSS approaches to FSDP/PP.
All visuals in this post are directly available here; they’re thrown onto the blog as color-inverted iframes.
Development process
All code for all visualizations was haphazardly generated over the course of a week by repeatedly prompting o1/pro and Sonnet 3.5 in Cursor. If there is any demand for it, I can throw the source files up on Github, but IMO they are quite dirty and not worth sharing.
It was disappointing to learn over the course of my work that all public models are very far from one-shotting visualizations in the vein of this post. I remain hopeful that the next generation of models will put me into a fast and final retirement.
On correctness
Frankly speaking, it is hard for a single person to nail 100% of all details involved with no external feedback, and I fully expect to have made multiple egregious errors in understanding.
For example, one obvious problem with all the visualizations is the absence of meaning in the relative durations of each communication. It’s not possible to define what the relative scale of, e.g., FS/DP comms vs TP/CP comms should be without adding more inputs to the visualizations, because that relative scale depends on the batch size : model size ratio.
I publish this post in the hopes of receiving accurate evaluations of what is false in my understanding of distributed parallelism.