In the previous section, we introduced how IB SHARP works. With the Hopper architecture, NVIDIA introduced the third-generation NVSwitch; similar to IB SHARP between machines, NVLink SHARP (abbreviated as nvls) performs the reduction inside the NVSwitch within a machine. In this section, we'll explain how NVLink SHARP works.
For simplicity, we’ll use examples with nranks=2, but it’s worth noting that nvls isn’t actually used when nranks=2.
During initialization, NCCL mainly determines whether nvls is supported. If it is, nvlsSupport is set to 1 and comm->nvlsChannels is set. Note that nvlsChannels is the number of blocks the kernel actually launches, not the number of channels found during the search.
Then the search process begins. nvls channels are different from channels found by other algorithms. Let’s look at the specifics:
```cpp
ncclResult_t ncclTopoSearchTryNvls(struct ncclTopoSystem* system, struct ncclTopoGraph* graph, struct ncclTopoGraph* saveGraph, int g, int ngpus, int* time) {
  struct ncclTopoNode* nvs;
  struct ncclTopoNode* gpu;
  int d0 = 0;
  // See if there is enough bandwidth for NVS->GPU traffic
  do {
    NCCLCHECK(ncclTopoFollowPath(system, graph, NVS, 0, GPU, d0, d0 == g ? 2 : 1, &gpu));
    d0++;
  } while (gpu && d0 < system->nodes[GPU].count);
  if (gpu == NULL) {
    d0--;
  } else {
    int d1 = 0;
    // See if there is enough bandwidth for GPU->NVS traffic
    do {
      NCCLCHECK(ncclTopoFollowPath(system, graph, GPU, d1, NVS, 0, d1 == g ? 2 : 1, &nvs));
      d1++;
    } while (nvs && d1 < system->nodes[GPU].count);
    if (nvs == NULL) {
      d1--;
    } else {
      // Both directions worked. Move on to the next path.
      NCCLCHECK(ncclTopoSearchRecGpu(system, graph, saveGraph, NULL, ngpus, -1, -1, 0, time));
    }
    while (d1) {
      d1--;
      NCCLCHECK(ncclTopoFollowPath(system, graph, GPU, d1, NVS, 0, d1 == g ? -2 : -1, &nvs));
    }
  }
  while (d0) {
    d0--;
    NCCLCHECK(ncclTopoFollowPath(system, graph, NVS, 0, GPU, d0, d0 == g ? -2 : -1, &gpu));
  }
  return ncclSuccess;
}
```
This determines whether the bandwidth meets the requirements. Let's first look at the actual GPU and NVSwitch topology in a machine, shown in Figure 1.
Assuming the bandwidth required by the current search condition is bw and g is GPU0, the logic for searching one channel is to check that the bidirectional bandwidth between every GPU node and the NVSwitch is at least bw; if so, bw is subtracted from each link. For GPU0 specifically, the check is whether the remaining link bandwidth is at least 2 * bw, which we'll explain later.
Then ncclTopoSearchRecGpu continues executing. Note that step is specified as ngpus, so this completes the search for one channel. nChannels becomes 1, so the next time ncclTopoSearchTryGpu executes, it will start from GPU1, repeating this process until ngpus channels are found. Taking four GPUs as an example, the searched channels are shown as follows:
```
0
1
2
3
```
While channels in other algorithms represent the transmission order within a node, nvls channels are different. For example, the 0 in the first channel, which NCCL calls the nvlsHead, indicates which rank is responsible for operations like the reduce of a given memory segment, as we'll see later.
```cpp
static ncclResult_t connectNvls(struct ncclComm* comm, int* nvlsHeads, struct ncclTopoGraph* nvlsGraph) {
  int nHeads = nvlsGraph->nChannels;
  int headRank = -1;
  for (int h = 0; h < nHeads; h++) {
    if (nvlsGraph->intra[h * comm->localRanks] == comm->rank) headRank = h;
  }
  for (int c = 0; c < comm->nvlsChannels; c++) {
    struct ncclChannel* channel = comm->channels + c;
    channel->nvls.nHeads = nHeads;
    for (int h = 0; h < nHeads; h++) channel->nvls.up[h] = comm->nRanks + 1 + h;
    for (int h = nHeads; h < NCCL_MAX_NVLS_ARITY; h++) channel->nvls.up[h] = -1;
    channel->nvls.down = comm->nRanks + 1 + headRank;
    channel->nvls.out = -1;       // NVLS+SHARP not yet implemented.
    channel->nvls.headRank = headRank;
    channel->nvls.treeUp = channel->nvls.treeDown[0] = channel->nvls.treeDown[1] = channel->nvls.treeDown[2] = -1;
    channel->nvls.node = comm->node;
    channel->nvls.nNodes = comm->nNodes;
  }
  if (comm->nNodes == 1) return ncclSuccess;
}
```
First, headRank is computed: it records which channel this rank is the head of. Then all nvls channels are set up. The up and down fields here are used to index peers; since the nvls connections start at peer index nRanks + 1, that offset is added. up holds all heads, and down is this rank's own head slot (headRank), i.e. itself.
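As a concrete illustration (hypothetical numbers, not taken from the NCCL source): on a single node with nRanks = 4 and nHeads = 4, every rank ends up with the same up list, while down differs per rank.

```cpp
// Illustration only: NVLS peer indices start at nRanks + 1, as described above.
int nRanks = 4, nHeads = 4;
int up[4], down;
for (int h = 0; h < nHeads; h++) up[h] = nRanks + 1 + h; // {5, 6, 7, 8} on every rank
int headRank = 2;               // e.g. this rank is the head of channel 2
down = nRanks + 1 + headRank;   // 7, i.e. this rank's own head slot
```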
The overall memory registration process is shown in Figure 4. First, a multicast object is created through cuMulticastCreate, where the handle in the figure points to this multicast object. Then each GPU associates the current device with this multicast object through cuMulticastAddDevice, then allocates memory, and finally associates the allocated memory with the handle through cuMulticastBindAddr or cuMulticastBindMem.
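A minimal sketch of that flow using the CUDA 12.x driver multicast API is shown below (error handling and the cross-process handle exchange are omitted; names such as nRanks, dev and ucHandle are placeholders, not NCCL's):

```cpp
#include <cuda.h>

// Sketch: create/join a multicast group and bind local memory into it.
// isRoot would be rank 0; exporting and broadcasting the handle is omitted.
void multicastSetupSketch(bool isRoot, int nRanks, CUdevice dev,
                          CUmemGenericAllocationHandle ucHandle, size_t size,
                          CUmemGenericAllocationHandle* mcHandle) {
  CUmulticastObjectProp mcProp = {};
  mcProp.numDevices  = nRanks;   // all GPUs that will join the group
  mcProp.size        = size;     // assumed already rounded to the granularity
  mcProp.handleTypes = CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR;

  size_t gran;
  cuMulticastGetGranularity(&gran, &mcProp, CU_MULTICAST_GRANULARITY_RECOMMENDED);

  if (isRoot) cuMulticastCreate(mcHandle, &mcProp);  // rank 0 creates, others import its handle

  cuMulticastAddDevice(*mcHandle, dev);              // every rank joins with its device
  cuMulticastBindMem(*mcHandle, /*mcOffset=*/0, ucHandle, /*memOffset=*/0, size, /*flags=*/0);
}
```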
buffSize is the buff size for the SIMPLE protocol, memSize is used to save head and tail, then calculate how much memory needs to be allocated in total. nHeads is the number of channels found, which is nRanks. We’ll see later why the memory size is configured this way. Then save the total memory size and localRanks information to resources.
rank0 executes nvlsGroupCreate, creates a multicast object through cuMulticastCreate, and saves it in resources->mcHandle. Since it needs to be shared across processes, it needs to be converted to a shareable handle, which is done through direct memcpy here.
Then all ranks execute bootstrapIntraNodeBroadcast, rank0 broadcasts the shared handle of the multicast object to all ranks’ shareableHandle. After other ranks receive the shareableHandle, they convert it to mcHandle through nvlsGroupConnect. Then through cuMulticastAddDevice, the current card is bound to mcHandle, so all ranks now have the multicast object corresponding to mcHandle.
```cpp
ncclResult_t nvlsGroupBindMem(struct ncclComm* comm, struct ncclNvlsSharedRes* resources) {
  size_t size = resources->size;
  size_t granularity;
  CUdeviceptr ptr = 0;
  CUmemAllocationProp prop;

  memset(&prop, 0, sizeof(prop));
  prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
  prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
  prop.location.id = resources->dev;
  prop.requestedHandleTypes = NVLS_CU_MEM_HANDLE_TYPE;
  CUCHECK(cuMemGetAllocationGranularity(&granularity, &prop, CU_MEM_ALLOC_GRANULARITY_RECOMMENDED));
  resources->ucGran = granularity;

  // Map a VA for UC memory
  CUCHECK(cuMemAddressReserve(&ptr, size, granularity, 0U, 0));

  // Alloc local physical mem for this NVLS group
  CUCHECK(cuMemCreate(&resources->ucHandle, size, &prop, 0));
  CUCHECK(cuMemMap(ptr, size, 0, resources->ucHandle, 0));
  CUCHECK(cuMemSetAccess(ptr, size, &resources->accessDesc, 1));
  CUDACHECK(cudaMemset((void*)ptr, 0, size));
  resources->ucBuff = (char*)ptr;

  CUCHECK(cuMulticastBindMem(resources->mcHandle, 0/*mcOffset*/, resources->ucHandle, 0/*memOffset*/, size, 0/*flags*/));
  return ncclSuccess;
}
```
Then physical memory is allocated and mapped into the virtual address space: first a virtual address range is reserved at ptr, then physical memory is allocated into ucHandle, the physical memory behind ucHandle is mapped to ptr, and ptr is assigned to ucBuff, as shown in Figure 5.
Finally, bind the physical memory corresponding to ucHandle to mcHandle through cuMulticastBindMem.
Then begin executing nvlsGroupMapMem to map mcHandle to virtual address space.
```cpp
ncclResult_t nvlsGroupMapMem(struct ncclComm* comm, struct ncclNvlsSharedRes* resources) {
  size_t size = resources->size;
  CUdeviceptr ptr = 0;

  // Create a VA for the NVLS
  CUCHECK(cuMemAddressReserve(&ptr, size, resources->granularity, 0U, 0));
  // Map the VA locally
  CUCHECK(cuMemMap(ptr, size, 0, resources->mcHandle, 0));
  resources->mcBuff = (char*)ptr;
  INFO(NCCL_NVLS, "NVLS Mapped MC buffer at %p size %zi", resources->mcBuff, size);

  // Having completed the BindMem we can now call SetAccess
  // NB: It will block until all ranks have bound to the Group
  CUCHECK(cuMemSetAccess((CUdeviceptr)resources->mcBuff, size, &resources->accessDesc, 1));
  return ncclSuccess;
}
```
Similarly, a virtual address range is reserved at ptr, mcHandle is mapped to ptr, and the result is saved in mcBuff. At this point the layout is as shown in Figure 6.
At this point, the physical memory is mapped to both ucBuff and mcBuff. ucBuff is the Unicast buffer: accesses to it only affect the current device's memory. mcBuff is the Multicast buffer: accesses to it are broadcast by NVSwitch to all devices added to mcHandle.
Then begin recording the memory to each peer’s connection buffer.
If GPU0 executes multimem.ld_reduce on peer[0]->send[1].buff, this will load data from the corresponding positions of GPU0 and GPU1, perform reduce, and save the result in d.
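At the instruction level, such a load-reduce might look like the following sketch (assumes sm_90, CUDA 12.x, and that mc_addr is a global-space address inside the multicast mapping; this is an illustration, not NCCL's applyLoadMultimem):

```cpp
// Sketch: one multimem.ld_reduce pulls the same 16-byte offset from every
// device bound to the multicast group and returns the element-wise sum.
__device__ __forceinline__ float4 multimemLoadSumF32x4(const float* mc_addr) {
  float4 d;
  asm volatile(
      "multimem.ld_reduce.relaxed.sys.global.add.v4.f32 {%0,%1,%2,%3}, [%4];"
      : "=f"(d.x), "=f"(d.y), "=f"(d.z), "=f"(d.w)
      : "l"(mc_addr)
      : "memory");
  return d;
}
```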
Using ReduceScatter as an Example to Explain the Kernel Process
Version 2.19 has undergone significant changes, so before looking at the nvls kernel, let’s first introduce how the new kernel execution process works.
Let’s first look at how the kernel is launched and how it executes step by step to reach the corresponding device functions for protos, algos, etc.
Due to multiple APIs, reduce types, data types, algorithms, and protocols, where kernels are the Cartesian product of these variables, NCCL uses generate.py to generate these kernel definitions, mainly generating two arrays: ncclDevKernelForFunc and ncclDevFuncTable, for global functions and device functions respectively.
During the enqueue process, workFuncIndex is calculated through computeColl, then kernelFn is recorded as ncclDevKernelForFunc[workFuncIndex]. Taking ReduceScatter sum as an example, the workFuncIndex is 485, and looking up ncclDevKernelForFunc[485] gives us ncclDevKernel_ReduceScatter_Sum_f32_RING_LL. Note that while we’re actually using the SIMPLE protocol, this kernel is LL. Let’s see how it’s further forwarded.
ncclShmem is located in shared memory and stores the parameter information needed by the kernel, such as channelId, comm, channel, etc. All threads in a block will use this information for sending and receiving data. In previous versions, all threads in a block had the same peer, while in the new version different threads may correspond to different peers. For example, in send/recv, a block can send/receive to/from 8 peers, and in nvls discussed in this section, different warps in a block use a pipelined approach to complete the overall process. Therefore, the data structure ncclShmemGroup groups was introduced, where a group represents threads executing the same logic, and the required conn, srcs, dsts information for the group is stored in groups.
The mapping between blocks and channels is then selected, i.e. which channel the current block should handle is calculated.
```cpp
template<int SpecializedFnId, typename SpecializedRunWork>
__device__ void ncclKernelMain(struct ncclDevComm* comm, uint64_t channelMask, struct ncclWork* workHead) {
  ...
  while (true) {
    // Notify host that all fifo reads are complete.
    ...
    if (0 <= SpecializedFnId && ncclShmem.work.header.funcIndex == (unsigned)SpecializedFnId) {
      SpecializedRunWork().run(&ncclShmem.work);
    } else {
      ncclDevFuncTable[ncclShmem.work.header.funcIndex]();
    }
    int workIxNext = ncclShmem.work.header.workNext;
    __syncthreads();
    ...
  }
  ...
}
```
With funcIndex being 485 and SpecializedFnId being 483, it will look up the corresponding function in ncclDevFuncTable again, which is ncclDevFunc_ReduceScatter_Sum_f32_RING_SIMPLE, thus finding the function to execute.
```cpp
DEFINE_ncclDevFunc(ReduceScatter_Sum_f32_RING_SIMPLE, ncclFuncReduceScatter, FuncSum, float, NCCL_ALGO_RING, NCCL_PROTO_SIMPLE)

#define DEFINE_ncclDevFunc(suffix, coll, redop, ty, algo, proto) \
  __device__ void ncclDevFunc_##suffix() { \
    RunWork<coll, ty, redop<ty>, algo, proto>().run(&ncclShmem.work); \
  }
```
```cpp
template<ncclFunc_t Fn, typename T, typename RedOp, int Algo, int Proto>
struct RunWork {
  // This __forceinline__ is necessary. The compiler was inserting a function call
  // here from the LL ncclKernel.
  __device__ __forceinline__ void run(ncclWork* w) {
    int wid = threadIdx.x / WARP_SIZE;
    ncclWorkElem* we = w->header.type == ncclWorkTypeRegColl ? &w->regElems[0].elem : &w->elems[0];
    int stride = w->header.type == ncclWorkTypeRegColl ? sizeof(ncclWorkElemReg) : sizeof(ncclWorkElem);
    #pragma unroll 1
    while ((char*)we + stride <= (char*)(w+1) && we->isUsed) {
      if (wid < we->nWarps) {
        RunWorkElement<Fn, T, RedOp, Algo, Proto>().run(we);
      }
      we = (ncclWorkElem*)((char*)we + stride);
    }
  }
};
```
Looking at the definition of this function, we can see that Fn is ncclFuncReduceScatter, T is float, RedOp is FuncSum<float>, the algorithm is NCCL_ALGO_RING, and the protocol is NCCL_PROTO_SIMPLE; it then begins executing runRing.
```cpp
template<typename T, typename RedOp, typename Proto>
__device__ __forceinline__ void runRing(ncclWorkElem* args) {
  ...
  const ssize_t loopSize = nChannels * chunkSize;
  const ssize_t size = args->count;

  Primitives<T, RedOp, FanSymmetric<1>, 0, Proto, 0>
    prims(tid, nthreads, &ring->prev, &ring->next, args->sendbuff, args->recvbuff, args->redOpArg);

  for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
    ssize_t realChunkSize;
    ...
    realChunkSize = int(realChunkSize);

    ssize_t chunkOffset = gridOffset + bid * int(realChunkSize);

    /// begin ReduceScatter steps ///
    ssize_t offset;
    int nelem = min(realChunkSize, size - chunkOffset);
    int rankDest;

    // step 0: push data to next GPU
    rankDest = ringRanks[nranks - 1];
    offset = chunkOffset + rankDest * size;
    prims.send(offset, nelem);
```
In the first step, send is executed for the block at offset rankDest * size (with rankDest = ringRanks[nranks-1]), i.e. that block of the user input is sent into the next rank's buffer.
```cpp
    // k-2 steps: reduce and copy to next GPU
    for (int j = 2; j < nranks; ++j) {
      rankDest = ringRanks[nranks - j];
      offset = chunkOffset + rankDest * size;
      prims.recvReduceSend(offset, nelem);
    }
```
Then, for the next nranks - 2 steps, recvReduceSend is executed: the data the previous rank sent into this rank's buffer is reduced with the corresponding position of this rank's user input, and the result is sent to the next rank.
```cpp
    // step k-1: reduce this buffer and data, which will produce the final result
    rankDest = ringRanks[0];
    offset = chunkOffset + rankDest * size;
    prims.recvReduceCopy(offset, chunkOffset, nelem, /*postOp=*/true);
  }
}
```
Finally, execute recvReduceCopy, which means reducing the data sent from the previous rank with the corresponding position in its user input and copying it to the user output.
Reviewing the construction of primitives in ReduceScatter, recvPeers is the previous rank in the ring, sendPeers is the next rank in the ring, inputBuf and outputBuf are the input/output buffers provided by the user executing the API, group is the default parameter 0. redOpArg is used for operations like mean, where it would be set to nranks and divided by nranks during reduceCopy. In this example of sum operation, redOpArg can be ignored.
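For instance, a simplified illustration of that idea (not the actual FuncPreMulSum/FuncSumPostDiv machinery in NCCL):

```cpp
// Sketch: for an average, redOpArg would carry nRanks and the post-op applied
// at the end of reduceCopy divides the accumulated sum by it; for a plain sum,
// redOpArg is simply ignored.
__device__ __forceinline__ float applyMeanPostOp(float sum, unsigned long long redOpArg) {
  return sum / (float)redOpArg;
}
```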
In the template parameters, Direct and P2p are 0, Fan is FanSymmetric<1>, which records how many recv and send operations there are, with MaxRecv and MaxArity both being 1.
```cpp
__device__ Primitives(...) {
  // For send operations, we need an extra warp to overlap the threadfence and the copy
  this->nworkers = nthreads - (MaxSend > 0 && nthreads - WARP_SIZE >= 64 ? WARP_SIZE : 0);

  int nrecv = 0, nsend = 0;
  while (nrecv < MaxRecv && recvPeers[nrecv] != -1) nrecv++;
  while (nsend < MaxSend && sendPeers[nsend] != -1) nsend++;
  this->fan = Fan(nrecv, nsend);

  constexpr int ThreadPerSync = 8;
  static_assert(MaxSend <= ThreadPerSync && MaxRecv <= ThreadPerSync, "Not enough threads to cover all peers");

  int g = tid / ThreadPerSync;
  int ng = nthreads / ThreadPerSync;
  index = tid % ThreadPerSync;
  flags = 0;
  if (g == 0) {
    if (index < nrecv) flags |= RoleWaitRecv;
    if (index == nrecv) flags |= RoleInput;
  } else if (g == 1) {
    if (index < nsend) flags |= RoleWaitSend;
    if (index == nsend) flags |= RoleOutput;
  } else if (g == ng - 2) {
    if (index < nrecv) flags |= RolePostRecv;
  } else if (g == ng - 1) {
    if (index < nsend) flags |= RolePostSend;
  }
  ...
}
```
nthreads is the total number of executing threads, in this example equal to the number of threads in the block. nworkers is the number of threads actually doing work. Since sending requires one warp to execute threadfence, the actual working threads are nthreads minus one warp, though when the total number of warps is small, an independent synchronization warp won’t be used. Record nsend and nrecv, both being 1 in this case.
Then each thread's role is assigned by dividing nthreads into groups of 8 threads. Assuming there are n groups in total and there are two recvPeers and two sendPeers, the role allocation is shown in Figure 8. WaitRecv means the thread is responsible for waiting until data is available to receive in the fifo: g[0]'s thr[0] waits on the 0th recvPeer and thr[1] on the 1st. The Input thread is responsible for holding the user buffer address. PostRecv is responsible for notifying the recvPeer after data has been received: g[n-2]'s thr[0] notifies the 0th recvPeer and thr[1] the 1st. The send side is handled analogously.
```cpp
__device__ __forceinline__ void loadRecvConn(ncclDevChannelPeer* peer, int connIndex, struct ncclWorkElem* e) {
  if (flags & (RoleWaitRecv|RolePostRecv)) {
    auto *conn = &peer->recv[connIndex];
    step = conn->step;
    step = roundUp(step, SlicePerChunk*StepPerSlice);
    if (flags & RolePostRecv) {
      connStepPtr = conn->head;
      *connStepPtr = step; // Return credits in case we rounded up.
    }
    if (flags & RoleWaitRecv) {
      ncclShmem.groups[group].recvConns[index] = conn; // WaitRecv role saves since that's who needs it in setDataPtrs()
      flags |= (conn->flags & NCCL_NVLS_MIN_POLL) ? NvlsMinPolling : 0;
      connStepPtr = conn->tail;
      connStepCache = loadStepValue(connStepPtr);
      flags |= (conn->offsFifo != nullptr) ? OffsFifoEnabled : 0;
      if (Direct) { ... }
      if (flags & OffsFifoEnabled) connOffsFifoPtr = conn->offsFifo;
      connEltsFifo = (T*)conn->buffs[NCCL_PROTO_SIMPLE];
    }
  }
}
```
Only threads with RoleWaitRecv and RolePostRecv execute loadRecvConn, reading the step which, as mentioned in previous sections, represents the position in the fifo. RolePostRecv threads are responsible for notifying recvPeer, so they need to save the head pointer from conn to connStepPtr. RoleWaitRecv threads are responsible for waiting until new data is in the fifo, so they need to save the tail pointer from conn to connStepPtr and cache the content in connStepCache to avoid frequent global memory reads. Finally, conn->buff (the fifo) is recorded in connEltsFifo.
```cpp
__device__ __forceinline__ void loadSendConn(ncclDevChannelPeer* peer, int connIndex, struct ncclWorkElem* e) {
  if (flags & (RoleWaitSend|RolePostSend)) {
    auto *conn = &peer->send[connIndex];
    step = conn->step;
    step = roundUp(step, SlicePerChunk*StepPerSlice);
    if (flags & RolePostSend) {
      connStepPtr = conn->tail;
      connEltsFifo = (T*)conn->buffs[NCCL_PROTO_SIMPLE];
    }
    if (flags & RoleWaitSend) {
      ncclShmem.groups[group].sendConns[index] = conn; // WaitSend role saves since that's who needs it in setDataPtrs()
      flags |= (conn->flags & NCCL_NVLS_MIN_POLL) ? NvlsMinPolling : 0;
      connStepPtr = conn->head;
      connStepCache = loadStepValue(connStepPtr);
      flags |= (conn->offsFifo != nullptr) ? OffsFifoEnabled : 0;
      if (flags & OffsFifoEnabled) connOffsFifoPtr = conn->offsFifo;
      connEltsFifo = (T*)conn->buffs[NCCL_PROTO_SIMPLE];
      ...
    }
  }
}
```
loadSendConn follows the same logic: RolePostSend threads are responsible for notifying send peer so they hold the tail pointer, RoleWaitSend threads are responsible for waiting for send peer so they hold the head pointer, then record the fifo.
Finally, setDataPtrs is executed to set userBuff as the user’s input and output.
```cpp
__device__ void setDataPtrs(void const* inputBuf, void* outputBuf, uint64_t redOpArg, struct ncclWorkElemReg* e) {
  if (flags & RoleInput) {
    userBuff = (T*)inputBuf;
    ncclShmem.redOpArgs[0] = redOpArg;  // scaler for local input
  }
  if (flags & RoleOutput) userBuff = (T*)outputBuf;
  ...
}
```
In the template parameters, Recv indicates whether recv needs to be executed, Send indicates whether Send needs to be executed, SrcBuf indicates whether the input contains user’s src buff, and DstBuf indicates whether the output contains user’s dst buff. Then calculate to get DirectRecv and DirectSend as 0, Src as 1, and Dst as 0.
In version 2.7.8, the worker threads responsible for data transmission and the synchronization threads are in the same loop, which introduces many branch instructions affecting performance. In the new version, the logic is split into two loops to improve performance, with worker threads executing the first loop and synchronization threads executing the second loop.
The RoleInput thread fills the user buff into srcs[0], then executes waitPeer. The waitPeer function is the previous waitSend and waitRecv, which will wait until data can be sent and received, and fill the data addresses into srcs and dsts.
```cpp
template <int DirectRecv, int DirectSend, int Recv, int Send, int Src, int Dst>
__device__ __forceinline__ void waitPeer(intptr_t srcIx, intptr_t dstIx, int offset, int nelts) {
  const bool isSendNotRecv = (Send && Recv) ? (flags & RoleWaitSend) : Send;
  const bool noRecvWait = DirectRecv && Src && (flags & DirectRead);        // no wait when directly reading from remote input
  const bool noSendWait = DirectSend && (flags & (DirectRead|DirectWrite)); // no wait in empty send (e.g. directScatter) or direct remote write
  if (((flags & (Recv*RoleWaitRecv)) && !noRecvWait) ||
      ((flags & (Send*RoleWaitSend)) && !noSendWait)) {
    int spins = 0;
    while (connStepCache + (isSendNotRecv ? NCCL_STEPS : 0) < step + StepPerSlice) {
      connStepCache = loadStepValue(connStepPtr);
      if (checkAbort(spins)) break;
    }
  }
  ...
}
```
Both noRecvWait and noSendWait are 0. For the RoleWaitSend thread, isSendNotRecv is 1 and connStepPtr is the head pointer: if the head pointer plus the queue capacity (NCCL_STEPS) is less than step + StepPerSlice, sending would overrun the queue, so the thread spins. For the RoleWaitRecv thread, isSendNotRecv is 0 and connStepPtr is the tail pointer: if step + StepPerSlice exceeds the tail pointer, there is no data in the queue yet, so it must wait.
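Reduced to its essentials, this is a ring-buffer credit test; a minimal sketch with simplified names (head/tail/step counted in steps, fifoSteps standing in for NCCL_STEPS), not the NCCL code itself:

```cpp
// Sender side: the slice [step, step + stepPerSlice) may only be written if it
// stays within fifoSteps of the consumer's head, otherwise it would overwrite
// slots that have not been consumed yet.
__device__ void waitSendSketch(volatile unsigned long long* head,
                               unsigned long long step, int stepPerSlice, int fifoSteps) {
  while (*head + fifoSteps < step + stepPerSlice) { /* spin, re-reading head */ }
}

// Receiver side: the producer's tail must have advanced past the slice before
// it can be read.
__device__ void waitRecvSketch(volatile unsigned long long* tail,
                               unsigned long long step, int stepPerSlice) {
  while (*tail < step + stepPerSlice) { /* spin, re-reading tail */ }
}
```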
Then the srcs and dsts arrays are filled with the corresponding fifo slots and step is updated. For recvReduceSend, srcs[0] is the user buff, srcs[1] is the fifo of the previous rank, and dsts[0] is the fifo of the next rank, achieving the function described earlier: receive data from the previous rank, reduce it with the user's input buff, and send the result to the next rank.
```cpp
template <int DirectRecv1, int DirectSend1, int Recv, int Send, int SrcBuf, int DstBuf>
__device__ __forceinline__ void genericOp(intptr_t srcIx, intptr_t dstIx, int nelem, bool postOp) {
  ...
  if (tid < nworkers && offset < nelem) {
    do {
      ...
      subBarrier();
      int workSize = ncclShmem.aborted ? 0 : sliceSize;
      if (DirectRecv && ncclShmem.groups[group].srcs[0] == ncclShmem.groups[group].dsts[0]) {
        ...
      } else if (DirectSend && !DirectRecv && SrcBuf != Input && ncclShmem.groups[group].dsts[Dst] == nullptr) {
        ...
      } else {
        constexpr int PreOpSrcs = SrcBuf != Input ? 0 :
                                  DirectRecv*MaxRecv == NCCL_MAX_DIRECT_ARITY ? (1+NCCL_MAX_DIRECT_ARITY) : 1;
        reduceCopy<Unroll, RedOp, T,
          MultimemSrcs, Recv+Src, Recv*MaxRecv+Src,
          MultimemDsts, Send+Dst, Send*MaxSend+Dst, PreOpSrcs>
          (tid, nworkers, ncclShmem.redOpArgs[0], ncclShmem.redOpArgs, postOp,
           Recv*fan.nrecv()+Src, ncclShmem.groups[group].srcs,
           Send*fan.nsend()+Dst, ncclShmem.groups[group].dsts,
           workSize);
      }
      barrier(); // This barrier has a counterpart in following loop
      postPeer<Recv, Send>(0 < sliceSize);
      offset += sliceSize;
      slice += 1;
    } while (slice < SlicePerChunk && offset < nelem);
  }
  while (slice < SlicePerChunk) {
    sliceSize = sliceSize < nelem-offset ? sliceSize : nelem-offset;
    barrier(); // Has counterpart in preceding worker-only loop.
    postPeer<Recv, Send>(0 < sliceSize);
    offset += sliceSize;
    slice += 1;
  }
}
```
After waitPeer completes, indicating data transmission can begin, it first executes subBarrier() to synchronize all worker threads, ensuring they enter the data transmission logic only after waitPeer completes. Then it executes reduceCopy to perform reduce from srcs and copy to dsts. Then it executes barrier() for all threads, including worker and synchronization threads, because synchronization threads can only start post after data transmission ends. Let’s look at the postPeer executed by synchronization threads.
RolePost threads need to update step and write it to connStepPtr. RolePostRecv holds the head pointer and can write directly. RolePostSend holds the tail pointer, so to ensure the data writes complete before the post, a fence is needed. An acq_rel fence is used here; release semantics would be sufficient for this scenario, but the PTX ISA has no standalone release fence. For the reading side, matching read barriers would normally be needed, but NCCL's implementation uses volatile loads, which bypass L1, so no extra barrier is used there.
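A minimal sketch of that post logic under the stated assumptions (simplified names, not NCCL's postPeer):

```cpp
// Producer publishing a finished slice: order the buffer writes before the
// tail update; PTX has no standalone release fence, hence fence.acq_rel.
__device__ void postSendSketch(volatile unsigned long long* tail, unsigned long long newStep) {
  asm volatile("fence.acq_rel.sys;" ::: "memory");
  *tail = newStep;   // volatile store; the consumer polls it with volatile/acquire loads
}

// Returning credits after consuming a slice needs no data fence.
__device__ void postRecvSketch(volatile unsigned long long* head, unsigned long long newStep) {
  *head = newStep;
}
```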
The scatter threads perform the scatter operation through prims, with sendPeers being nvls->up, i.e. all heads (which, on a single node, means all ranks). inputBuf is the user's input args->sendbuff, and connIndexSend is 1, so the send conn at index 1 is loaded.
```cpp
__device__ __forceinline__ void scatter(intptr_t inpIx, ssize_t totalElem, int peerElem, ssize_t peerOffset, int skip, int shift) {
  ScatterGatherOp<0, 0, 0, 1>(inpIx, -1, totalElem, peerElem, peerOffset, skip, shift, /*postOp=*/false);
}

template <int DirectRecv1, int DirectSend1, int Recv, int Send>
__device__ __forceinline__ void ScatterGatherOp(intptr_t inpIx, intptr_t outIx, ssize_t totalElem, int peerElem, ssize_t peerOffset, int skip, int shift, bool postOp) {
  constexpr int DirectRecv = 1 && Direct && DirectRecv1;
  constexpr int DirectSend = 1 && Direct && DirectSend1;
  int offset = 0; // slice offset
  int sliceSize = stepSize*StepPerSlice;
  int dataSize = max(DIVUP(peerElem, 16*SlicePerChunk)*16, sliceSize/32);  // per-peer slice size

  #pragma unroll
  for (int slice = 0; slice < SlicePerChunk; ++slice) {
    ssize_t realSize = max(0, min(dataSize, peerElem-offset));
    bool fenceNeeded = false;
    if (tid < nworkers) {
      if (Send) {
        // Scatter pre-scales data of input buffer only in non-Direct case
        constexpr int PreOpSrcs = DirectSend ? 0 : 1;
        if (flags & RoleInput) ncclShmem.groups[group].srcs[0] = userBuff + inpIx + offset;
        // realSize is not accurate here; but intra-node does not rely on sizes FIFO
        waitPeer<0, DirectSend, 0, 1, 1, 0>(0, inpIx, offset, realSize);
        subBarrier();
        #pragma unroll
        // Loop over peers
        for (int j = 0; j < fan.nsend(); j++) {
          int i = (j+shift)%fan.nsend();
          ssize_t pOffset = i*peerOffset;
          // Skip the data I am responsible of reducing myself
          if (skip >= 0 && i >= skip) pOffset += peerElem;
          void* src0 = (T*)ncclShmem.groups[group].srcs[0] + pOffset;
          ssize_t realPeerSize = min(realSize, totalElem-pOffset);
          if (realPeerSize > 0 && ncclShmem.groups[group].dsts[i] != nullptr) {
            reduceCopy<Unroll, RedOp, T, 0, 1, 1, 0, 1, 1, PreOpSrcs>
              (tid, nworkers, ncclShmem.redOpArgs[0], ncclShmem.redOpArgs, false,
               1, &src0, 1, ncclShmem.groups[group].dsts+i, realPeerSize);
            // Mark for threadfence at the end
            fenceNeeded |= true;
          }
        }
      } else if (Recv) {
        ...
      }
    }
    fenceNeeded = barrierAny(fenceNeeded);
    postPeer<Recv, Send>(fenceNeeded);
    offset += realSize;
  }
}
```
As shown in Figure 9, scatter takes the data from userBuff and sends it to all sendPeer corresponding buffers (peer[0]->send[1].buff and peer[1]->send[1].buff) at peerOffset intervals.
For reduce threads, sendPeers is NULL, recvPeers is nvls->down, connIndexRecv is 0, so it loads the 0th recv conn and executes recv.
```cpp
else if (tid < tidEndReduce) {
  // Reduce through NVLS
  using Proto = ProtoSimple<1, 1, COLL_UNROLL, 1, 0>;
  Primitives<T, RedOp, FanAsymmetric<1, 0>, /*Direct=*/0, Proto, 0>
    prims(tid-tidEndScatter, nThreadsReduce, &nvls->down, NULL, NULL, args->recvbuff,
          args->redOpArg, 3*Proto::MaxGroupWidth, 0, 0);
  for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
    ssize_t offset = gridOffset + bid*chunkSize;
    int nelem = min(chunkSize, size-offset);
    prims.recv(offset, nelem);
  }
}
```
In the recv function, dst is set to args->recvbuff, src is the Multicast buffer corresponding to its own rank. As shown in Figure 10, after execution, GPU0’s recvBuff gets the reduced results corresponding to all cards.
```cpp
template<int Unroll, typename RedFn, typename T,
         int MultimemSrcs, int MinSrcs, int MaxSrcs,
         int MultimemDsts, int MinDsts, int MaxDsts, int PreOpSrcs,
         typename IntBytes>
__device__ __forceinline__ void reduceCopy(
    int thread, int nThreads,
    uint64_t redArg, uint64_t *preOpArgs, bool postOp,
    int nSrcs, void **srcPtrs, int nDsts, void **dstPtrs,
    IntBytes nElts) {
  int lane = thread%WARP_SIZE;
  // If a multimem src is present then our biggest pack size is limited to what
  // is supported for this redfn/type.
  constexpr int BigPackSize = (MultimemSrcs == 0) ? 16 : LoadMultimem_BigPackSize<RedFn>::BigPackSize;

  IntBytes nBytesBehind = 0;
  IntBytes nBytesAhead = nElts*sizeof(T);

  #if __cpp_if_constexpr
  if constexpr (BigPackSize > sizeof(T)) {
  #else
  if (BigPackSize > sizeof(T)) {
  #endif
    // Check that all pointers are BigPackSize aligned.
    bool aligned = true;
    if (lane < nSrcs) aligned &= 0 == cvta_to_global(srcPtrs[lane]) % (BigPackSize + !BigPackSize);
    if (lane < nDsts) aligned &= 0 == cvta_to_global(dstPtrs[lane]) % (BigPackSize + !BigPackSize);
    aligned = __all_sync(~0u, aligned);
    if (aligned) {
      reduceCopyPacks<RedFn, T, Unroll, BigPackSize,
        MultimemSrcs, MinSrcs, MaxSrcs, MultimemDsts, MinDsts, MaxDsts, PreOpSrcs>
        (nThreads, /*&*/thread, redArg, preOpArgs, postOp,
         nSrcs, srcPtrs, nDsts, dstPtrs, /*&*/nBytesBehind, /*&*/nBytesAhead);
      if (nBytesAhead == 0) return;

      reduceCopyPacks<RedFn, T, /*Unroll=*/1, BigPackSize,
        MultimemSrcs, MinSrcs, MaxSrcs, MultimemDsts, MinDsts, MaxDsts, PreOpSrcs>
        (nThreads, /*&*/thread, redArg, preOpArgs, postOp,
         nSrcs, srcPtrs, nDsts, dstPtrs, /*&*/nBytesBehind, /*&*/nBytesAhead);
      if (nBytesAhead == 0) return;
    }
  }
  ...
}
```
In the template parameters, MultimemSrcs indicates how many inputs are multimem, MultimemDsts indicates how many outputs are multimem. In the parameters, thread is the tid, nThreads is the total number of threads, there are nSrcs inputs with addresses in srcPtrs, nDsts outputs stored in dstPtrs, and nElts is the number of elements. The function’s purpose is to reduce all src and store them to all dst.
nBytesBehind indicates how much data has been processed, nBytesAhead indicates how much data remains unprocessed.
Then it checks if vectorized instructions can be used. BigPackSize is the granularity of load/store instructions. If the input is non-Multimem, it tries to use 16 bytes (128 bits); if it’s multimem, it needs to check Func and data type - in this case it’s FuncSum, so it’s also 16 bytes.
```cpp
template<typename Fn>
struct LoadMultimem_BigPackSize {
  using T = typename Fn::EltType;
  static constexpr bool IsSum = std::is_same<Fn, FuncSum<T>>::value ||
                                std::is_same<Fn, FuncPreMulSum<T>>::value ||
                                std::is_same<Fn, FuncSumPostDiv<T>>::value;
  static constexpr bool IsMinMax = std::is_same<Fn, FuncMinMax<T>>::value;
  static constexpr bool IsFloat = IsFloatingPoint<T>::value;
  static constexpr int BigPackSize =
    IsFloat && IsSum && sizeof(T) < 8 ? 16 :
    IsFloat && IsSum ? 8 :
    IsFloat && IsMinMax && sizeof(T) == 2 ? 16 :
    !IsFloat && (IsSum || IsMinMax) && sizeof(T) >= 4 ? sizeof(T) :
    /*multimem.ld_reduce not supported:*/ 0;
};
```
For vectorization, the inputs and outputs need to be aligned. Let's walk through the reduceCopyPacks logic, taking the aligned case as an example.
```cpp
template<typename RedFn, typename T, int Unroll, int BytePerPack,
         int MultimemSrcs, int MinSrcs, int MaxSrcs,
         int MultimemDsts, int MinDsts, int MaxDsts, int PreOpSrcs,
         typename IntBytes>
__device__ __forceinline__ void reduceCopyPacks(
    int nThreads, int &thread,
    uint64_t redArg, uint64_t *preOpArgs, bool postOp,
    int nSrcs, void **srcPtrs, int nDsts, void **dstPtrs,
    IntBytes &nBytesBehind, IntBytes &nBytesAhead) {
  // A hunk is the amount of contiguous data a warp consumes per loop iteration
  // assuming all threads partake.
  constexpr int BytePerHunk = Unroll*WARP_SIZE*BytePerPack;
  int nWarps = nThreads/WARP_SIZE;
  int warp = thread/WARP_SIZE;
  int lane = thread%WARP_SIZE;
  ...
}
```
BytePerPack is BigPackSize, i.e. 16 bytes. Assuming Unroll is 4, a warp's memory-access pattern is shown in Figure 11: a blue box is 32 segments of 16 bytes, and BytePerHunk is the contiguous data a warp processes at once (4 blue boxes). The first thread in the warp accesses the first 16 bytes of each of the 4 blue boxes, as indicated by the arrows.
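Plugging in the numbers used here (Unroll = 4, 32 lanes per warp, 16-byte packs), the geometry of Figure 11 works out as follows:

```cpp
// Worked example of the per-warp layout (illustration only).
constexpr int kUnroll      = 4;
constexpr int kWarpSize    = 32;
constexpr int kBytePerPack = 16;                                   // one 128-bit pack
constexpr int kBytePerHunk = kUnroll * kWarpSize * kBytePerPack;   // 4 * 32 * 16 = 2048 bytes
// Within its hunk, lane L touches the packs at byte offsets
//   L*16, L*16 + 512, L*16 + 1024, L*16 + 1536
// i.e. one 16-byte pack in each of the 4 "blue boxes" (each box = 32 * 16 = 512 bytes).
```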
```cpp
__device__ __forceinline__ void reduceCopyPacks(...) {
  // This thread's initial position.
  IntBytes threadBytesBehind = nBytesBehind + (warp*BytePerHunk + lane*BytePerPack);
  IntBytes threadBytesAhead = nBytesAhead - (warp*BytePerHunk + lane*BytePerPack);
  // Number of hunks to be consumed over all warps.
  IntBytes nHunksAhead = nBytesAhead/(BytePerHunk + !BytePerHunk);
  // Advance collective position.
  nBytesBehind += nHunksAhead*BytePerHunk;
  nBytesAhead -= nHunksAhead*BytePerHunk;
  if (Unroll == 1 && BytePerPack <= nBytesAhead) {
    // Only Unroll=1 can do partial hunks (where not all threads partake).
    nHunksAhead += 1;
    nBytesBehind += nBytesAhead - (nBytesAhead%(BytePerPack + !BytePerPack));
    nBytesAhead = nBytesAhead%(BytePerPack + !BytePerPack);
  }
  nHunksAhead -= warp;

  RedFn redFn(redArg);

  uintptr_t minSrcs[MinSrcs + !MinSrcs];
  uintptr_t minDsts[MinDsts + !MinDsts];
  #pragma unroll
  for (int s = 0; s < MinSrcs; s++)
    minSrcs[s] = cvta_to_global(srcPtrs[s]) + threadBytesBehind;
  #pragma unroll
  for (int d = 0; d < MinDsts; d++)
    minDsts[d] = cvta_to_global(dstPtrs[d]) + threadBytesBehind;
  ...
}
```
threadBytesBehind is the current thread's starting position and threadBytesAhead is the amount of data the current thread still has to process; then the first MinSrcs src pointers and MinDsts dst pointers are recorded into minSrcs and minDsts.
```cpp
__device__ __forceinline__ void reduceCopyPacks(...) {
  ...
  while (Unroll == 1 ? (BytePerPack <= threadBytesAhead) : (0 < nHunksAhead)) {
    BytePack<BytePerPack> acc[Unroll];

    { RedFn preFn(0 < PreOpSrcs ? preOpArgs[0] : 0);
      #pragma unroll Unroll
      for (int u = 0; u < Unroll; u++) {
        if (0 < MultimemSrcs) {
          // applyLoadMultimem uses relaxed semantics for same reason we use volatile below.
          acc[u] = applyLoadMultimem<RedFn, BytePerPack>(redFn, minSrcs[0]);
        } else {
          // Use volatile loads in case credits are polled for with volatile (instead of acquire).
          acc[u] = ld_volatile_global<BytePerPack>(minSrcs[0]);
          if (0 < PreOpSrcs) acc[u] = applyPreOp(preFn, acc[u]);
        }
        minSrcs[0] += WARP_SIZE*BytePerPack;
      }
    }
    ...
}
```
BytePack in this scenario is 16 bytes, described by a union, acc is used to store reduce results.
Then acc is initialized. If there is no multimem among the inputs, ld_volatile_global loads the 128 bits corresponding to the first blue box into acc[0], and the loop runs Unroll times to load the corresponding data from all blue boxes into acc. Let's look at ld_volatile_global.

Since BytePerPack is 16, it executes ld_volatile_global<16>, which is really ld.volatile.global.v2.b64 loading 128 bits into ans's u64[0] and u64[1].
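A sketch of what such a 16-byte volatile load boils down to (simplified BytePack; addr is assumed to already be a global-space address, as produced by cvta_to_global):

```cpp
struct BytePack16 { unsigned long long u64[2]; };

// Sketch: 128-bit volatile load from global memory, filling u64[0] and u64[1].
__device__ __forceinline__ BytePack16 ldVolatileGlobal16(unsigned long long addr) {
  BytePack16 ans;
  asm volatile("ld.volatile.global.v2.b64 {%0,%1}, [%2];"
               : "=l"(ans.u64[0]), "=l"(ans.u64[1])
               : "l"(addr)
               : "memory");
  return ans;
}
```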
When the input includes multimem, it will use applyLoadMultimem to load data, reduce it and store it to acc; st_global<16> stores BytePack’s u64[0] and u64[1] to addr using st.global.v2.b64.
We can see it uses multimem.ld_reduce.relaxed.sys.global.add.v4.f32 to execute reduce operations on 4 floats and store the result in acc.
After completing the first src read, continue reading other src to tmp, then execute reduce operation through Apply_Reduce, which performs elementwise sum on 4 floats.
```cpp
__device__ __forceinline__ void reduceCopyPacks(...) {
  ...
  while (Unroll == 1 ? (BytePerPack <= threadBytesAhead) : (0 < nHunksAhead)) {
    ...
    #pragma unroll (MinSrcs-1 + !(MinSrcs-1))
    for (int s = 1; s < MinSrcs; s++) {
      BytePack<BytePerPack> tmp[Unroll];
      RedFn preFn(s < PreOpSrcs ? preOpArgs[s] : 0);
      #pragma unroll Unroll
      for (int u = 0; u < Unroll; u++) {
        if (s < MultimemSrcs) {
          // applyLoadMultimem uses relaxed semantics for same reason we use volatile below.
          acc[u] = applyLoadMultimem<RedFn, BytePerPack>(redFn, minSrcs[s]);
        } else {
          // Use volatile loads in case credits are polled for with volatile (instead of acquire).
          tmp[u] = ld_volatile_global<BytePerPack>(minSrcs[s]);
        }
        minSrcs[s] += WARP_SIZE*BytePerPack;
      }
      #pragma unroll Unroll
      for (int u = 0; u < Unroll; u++) {
        if (s < PreOpSrcs) tmp[u] = applyPreOp(preFn, tmp[u]);
        acc[u] = applyReduce(redFn, acc[u], tmp[u]);
      }
    }

    for (int s = MinSrcs; (MinSrcs < MaxSrcs) && (s < MaxSrcs) && (s < nSrcs); s++) {
      uintptr_t src = cvta_to_global(srcPtrs[s]) + threadBytesBehind;
      BytePack<BytePerPack> tmp[Unroll];
      RedFn preFn(s < PreOpSrcs ? preOpArgs[s] : 0);
      #pragma unroll Unroll
      for (int u = 0; u < Unroll; u++) {
        // Use volatile loads in case credits are polled for with volatile (instead of acquire).
        tmp[u] = ld_volatile_global<BytePerPack>(src);
        src += WARP_SIZE*BytePerPack;
      }
      #pragma unroll Unroll
      for (int u = 0; u < Unroll; u++) {
        if (s < PreOpSrcs) tmp[u] = applyPreOp(preFn, tmp[u]);
        acc[u] = applyReduce(redFn, acc[u], tmp[u]);
      }
    }
    ...
  }
  ...
}
```
This completes the reduceCopy process. Note that in the nvls scenario, head, tail and other flags are mcBuff. Let’s look at how waitPeer determines if waiting is needed.
```cpp
inline __device__ uint64_t loadStepValue(uint64_t* ptr) {
  #if __CUDA_ARCH__ >= 900 && CUDART_VERSION >= 12010
  if (flags & NvlsMinPolling) {
    uint64_t ans;
    asm("multimem.ld_reduce.acquire.sys.global.min.u64 %0, [%1];" : "=l"(ans) : "l"(cvta_to_global(ptr)));
    return ans;
  }
  #endif
  // volatile is faster than acquire but not as correct. Make sure reduceCopy
  // loads data using volatile so it doesn't see stale data in L1.
  return ld_volatile_global(ptr);
}
```
In the nvls scenario, flags will have NvlsMinPolling. Here it uses multimem.ld_reduce to read all peers’ steps and takes the min, so it only receives data when all peers are ready. It uses acquire semantics here, paired with postPeer’s release to ensure memory ordering.
Note that ordinary (non-multimem) instructions are used here to write mcBuff. The PTX manual describes this as undefined behavior, but officially it is acceptable for write operations.
Like ring allreduce, the nvls loop size is nranks * chunkSize, which corresponds to the loopSize variable. Scatter threads are responsible for sending data from sendbuff to nvls->up; connIndexSend is 1, so the send conn at index 1 is used. After this step the state looks like Figure 12.
```cpp
else if (tid < tidEndReduce && nvls->headRank != -1) {
  if (!hasOut) {
    // Reduce, broadcast through NVLS
    using Proto = ProtoSimple<1, 1, COLL_UNROLL, 1, 1>;
    Primitives<T, RedOp, FanSymmetric<1>, /*Direct=*/1, Proto, 0>
      prims(tid-tidEndGather, nThreadsReduce, &nvls->down, &nvls->down, NULL, NULL,
            args->redOpArg, 2*Proto::MaxGroupWidth, 0, 0, args);
    for (ssize_t gridOffset = 0; gridOffset < size; gridOffset += loopSize) {
      ssize_t offset = gridOffset + (bid*nvls->nHeads + nvls->headRank)*chunkSize;
      int nelem = min(chunkSize, size-offset);
      prims.directRecvDirectSend(offset, offset, nelem);
    }
  }
  ...
}
```
Here MultimemSrcs and MultimemDsts are 1, using the 0th conn. After executing directRecvDirectSend, the effect is as shown in Figure 13. GPU 0 follows the yellow arrows, reducing the light yellow data blocks from both cards to get the dark yellow data block, then broadcasts the data to both cards using multimem.st. Similarly, GPU 1 follows the green arrows, reducing light blue data blocks to get the dark blue data block. At this point, both cards have the global data.
The gather process is ScatterGatherOp executing the recv branch, which we won’t elaborate on. After execution, it looks like Figure 14, completing the allreduce.
As mentioned earlier, when searching for channels, the bandwidth corresponding to the head needs to be multiplied by 2. Taking Figure 15 as an example and recalling the allreduce process, when GPU 0 executes ld_reduce, the bandwidth consumption is as shown in Figure 16.