The previous section introduced the communication link establishment process. This section will cover the operation process of ncclSend and ncclRecv within a single machine.
Communication within a single machine is conducted through kernels, so the entire communication process can be divided into two steps: first, preparing kernel-related parameters, and second, the actual kernel execution process.
For ease of explanation, unless otherwise specified, the following examples will be based on a single-machine, single-thread, two-GPU scenario. Here’s the test case:
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>
#include "cuda_runtime.h"
#include "nccl.h"
// Abort the program with file/line information if a CUDA runtime call fails.
#define CUDACHECK(cmd) do {                                   \
  cudaError_t e = cmd;                                        \
  if( e != cudaSuccess ) {                                    \
    fprintf(stderr, "Failed: Cuda error %s:%d '%s'\n",        \
        __FILE__,__LINE__,cudaGetErrorString(e));             \
    exit(EXIT_FAILURE);                                       \
  }                                                           \
} while(0)

// Abort the program with file/line information if an NCCL call fails.
#define NCCLCHECK(cmd) do {                                   \
  ncclResult_t r = cmd;                                       \
  if (r!= ncclSuccess) {                                      \
    fprintf(stderr, "Failed, NCCL error %s:%d '%s'\n",        \
        __FILE__,__LINE__,ncclGetErrorString(r));             \
    exit(EXIT_FAILURE);                                       \
  }                                                           \
} while(0)

// Single-process, single-thread test: one thread drives nDev GPUs (rank i on
// GPU i). Inside one NCCL group, every rank sends one `chunk` of floats to
// every rank (itself included) and receives one `chunk` from every rank.
// Chunk j of rank i's sendbuff goes to rank j; chunk j of recvbuff comes from rank j.
int main(int argc, char* argv[])
{
  // each process is using two GPUs; `const` makes comms[nDev] a fixed-size
  // array instead of a (non-standard in C++) variable-length array
  const int nDev = 2;
  int nRanks = nDev;
  int chunk = 1024*1024;      // elements exchanged with each peer
  int size = nDev * chunk;    // elements per buffer: one chunk per peer
  float** sendbuff = (float**)malloc(nDev * sizeof(float*));
  float** recvbuff = (float**)malloc(nDev * sizeof(float*));
  cudaStream_t* s = (cudaStream_t*)malloc(sizeof(cudaStream_t)*nDev);
  if (sendbuff == NULL || recvbuff == NULL || s == NULL) {
    fprintf(stderr, "Failed: host allocation\n");
    exit(EXIT_FAILURE);
  }

  // rank i uses GPU i: allocate its buffers and stream on that device
  for (int i = 0; i < nDev; ++i) {
    CUDACHECK(cudaSetDevice(i));
    CUDACHECK(cudaMalloc(sendbuff + i, size * sizeof(float)));
    CUDACHECK(cudaMalloc(recvbuff + i, size * sizeof(float)));
    // cudaMemset is byte-wise: every byte of sendbuff becomes 0x01, so each
    // float holds the bit pattern 0x01010101 (not 1.0f). That is enough for a
    // copy test, where send data only has to differ from the zeroed recvbuff.
    CUDACHECK(cudaMemset(sendbuff[i], 1, size * sizeof(float)));
    CUDACHECK(cudaMemset(recvbuff[i], 0, size * sizeof(float)));
    CUDACHECK(cudaStreamCreate(s+i));
  }

  ncclUniqueId id;
  ncclComm_t comms[nDev];
  // One process owns every rank, so the unique ID needs no broadcast.
  NCCLCHECK(ncclGetUniqueId(&id));
  // initializing NCCL, group API is required around ncclCommInitRank as it is
  // called across multiple GPUs in each thread/process
  NCCLCHECK(ncclGroupStart());
  for (int i=0; i<nDev; i++) {
    CUDACHECK(cudaSetDevice(i));
    NCCLCHECK(ncclCommInitRank(comms+i, nRanks, id, i));
  }
  NCCLCHECK(ncclGroupEnd());

  // calling NCCL communication API. Group API is required when using
  // multiple devices per thread/process
  NCCLCHECK(ncclGroupStart());
  for (int i=0; i<nDev; i++) {
    for (int j = 0; j < nDev; j++) {
      NCCLCHECK(ncclSend((const void*)(sendbuff[i] + j * chunk), chunk, ncclFloat, j, comms[i], s[i]));
      NCCLCHECK(ncclRecv((void*)(recvbuff[i] + j * chunk), chunk, ncclFloat, j, comms[i], s[i]));
    }
  }
  NCCLCHECK(ncclGroupEnd());

  // synchronizing on CUDA stream to complete NCCL communication
  for (int i=0; i<nDev; i++)
    CUDACHECK(cudaStreamSynchronize(s[i]));

  // freeing device memory on the device that owns it
  for (int i=0; i<nDev; i++) {
    CUDACHECK(cudaSetDevice(i));
    CUDACHECK(cudaFree(sendbuff[i]));
    CUDACHECK(cudaFree(recvbuff[i]));
  }

  // finalizing NCCL
  for (int i=0; i<nDev; i++) {
    NCCLCHECK(ncclCommDestroy(comms[i]));
  }

  // release the streams and the host-side bookkeeping arrays
  for (int i=0; i<nDev; i++) {
    CUDACHECK(cudaSetDevice(i));
    CUDACHECK(cudaStreamDestroy(s[i]));
  }
  free(sendbuff);
  free(recvbuff);
  free(s);
  return 0;
}
Communication Parameter Preparation
Let’s first look at the communication parameter preparation process. Before diving into details, let’s examine the overall picture.
Figure 1
At the bottom, send0 and recv0 are the send and receive buffers that the user allocated for rank0.
Let’s go through each component.
p2p channel
First, let’s see how p2p operation channels are created
// Decide how many channels p2p operations use, and which channel offsets a
// single peer's traffic is spread over.
//   comm->p2pnChannels        : total p2p channels (power of two)
//   comm->p2pnChannelsPerPeer : channels used in parallel for one peer (power of two)
//   comm->p2pChannels[c]      : offset of the c-th per-peer channel
// Channels beyond the collective ones are initialized here.
ncclResult_t ncclTopoComputeP2pChannels(struct ncclComm* comm) {
  // Start from the collective channel count, clamped by the max/min p2p
  // channel parameters.
  const int upper = (int)ncclParamMaxP2pNChannels();
  const int lower = (int)ncclParamMinP2pNChannels();
  comm->p2pnChannels = std::max(std::min(comm->nChannels, upper), lower);

  // The smallest non-negative channel count reported for any
  // (local GPU, rank) pair bounds the number of channels used per peer.
  int minChannels = comm->p2pnChannels;
  for (int gpu = 0; gpu < comm->topo->nodes[GPU].count; gpu++) {
    for (int peer = 0; peer < comm->nRanks; peer++) {
      int n;
      NCCLCHECK(ncclTopoGetNchannels(comm->topo, gpu, peer, &n));
      if (n >= 0 && n < minChannels) minChannels = n;
    }
  }

  // Round both counts up with nextPow2.
  comm->p2pnChannelsPerPeer = nextPow2(minChannels);
  comm->p2pnChannels = nextPow2(comm->p2pnChannels);

  // Channels in [nChannels, p2pnChannels) were not created earlier.
  for (int ch = comm->nChannels; ch < comm->p2pnChannels; ch++) {
    NCCLCHECK(initChannel(comm, ch));
  }

  // p2pChannels[c] is c with its bits mirrored inside the p2pnChannels
  // space (bit 1 -> p2pnChannels/2, bit 2 -> p2pnChannels/4, ...), so a
  // peer's few channels land far apart instead of on neighbouring indices.
  for (int c = 0; c < comm->p2pnChannelsPerPeer; c++) {
    int mirror = 0;
    int high = comm->p2pnChannels >> 1;
    for (int low = 1; low < comm->p2pnChannels; low <<= 1) {
      if (c & low) mirror |= high;
      high >>= 1;
    }
    comm->p2pChannels[c] = mirror;
  }

  INFO(NCCL_INIT, "%d coll channels, %d p2p channels, %d p2p channels per peer", comm->nChannels, comm->p2pnChannels, comm->p2pnChannelsPerPeer);
  return ncclSuccess;
}
Previously, when establishing the ringGraph, a series of rings were searched and channels were created based on these rings. Assuming there are now nChannels channels in total, and p2p needs p2pnChannels channels, if p2pnChannels is greater than nChannels, an additional p2pnChannels - nChannels channels will be created, while others are reused; otherwise, they can be directly reused.
For each send/recv operation, p2pnChannelsPerPeer channels are used to send/receive in parallel. When p2pnChannelsPerPeer is small and p2pnChannels is large, using consecutive channels starting at delta would keep all traffic on the first few channels and leave the rest idle. For example, with p2pnChannelsPerPeer = 2 and p2pnChannels = 32, rank0↔rank1 (delta = 1) would use channel[1] and channel[2], and rank0↔rank2 (delta = 2) would use channel[2] and channel[3]. The channels overlap, and channels above 3 are never used. To solve this, NCCL adds the offsets stored in the array p2pChannels[p2pnChannelsPerPeer], which are the per-peer channel indices with their bits mirrored. Here p2pChannels[0] = 0 and p2pChannels[1] = 16, so rank0↔rank1 uses channel[1] and channel[17], while rank0↔rank2 uses channel[2] and channel[18]. This spreads the traffic across the whole channel space.
For easier understanding, in subsequent examples we’ll assume both p2pnChannels and p2pnChannelsPerPeer are 1.
peerlist
Next, let’s look at peerlist, which is actually a member of comm->p2plist. Figure 1 only shows the peerlist; see the comments below for specific meanings.
// Per-peer record of the user's pending send/recv for that peer in the
// current group; ncclSaveP2p fills the entry indexed by the peer's rank.
struct ncclP2Pinfo {
const void* sendbuff; // user-specified buffer holding the data to send
void* recvbuff; // user-specified buffer that receives the data
ssize_t sendbytes; // length of sendbuff in bytes
ssize_t recvbytes; // length of recvbuff in bytes
};
// Connections that still have to be established before p2p operations can
// run, recorded per channel (ncclSaveP2p only adds peers that are not yet
// connected on that channel).
struct ncclP2PConnect {
int nrecv[MAXCHANNELS]; // nrecv[id]: number of ranks channel id will receive from
int nsend[MAXCHANNELS]; // nsend[id]: number of ranks channel id will send to
int* recv; // the nrecv[id] entries starting at recv[id * nranks] are the ranks channel id receives from
int* send; // the nsend[id] entries starting at send[id * nranks] are the ranks channel id sends to
};
// Per-comm accumulation of the p2p operations issued inside a group.
struct ncclP2Plist {
struct ncclP2Pinfo *peerlist; // one entry per peer rank, indexed by rank
int count; // number of ncclSend/ncclRecv calls saved since the group was last scheduled
struct ncclP2PConnect connect; // connections still to be set up, per channel
};
cudaLaunchParams#
In Figure 1, both intraParams and myParams are of type cudaLaunchParams. Communication is actually completed through kernels, and cudaLaunchParams records the kernel parameters.
// Kernel launch description: kernel function, grid and block dimensions,
// argument pointer array, dynamic shared memory size and stream. NCCL keeps
// one per intra-process rank (comm->intraParams) and hands the array to
// ncclLaunchCooperativeKernelMultiDevice.
struct cudaLaunchParams {
void *func;
dim3 gridDim;
dim3 blockDim;
void **args;
size_t sharedMem;
cudaStream_t stream;
};
At the end of initTransportsRank, the intra-process parameters are set. Here "intra" means ranks running in the same process on the current machine, i.e. with the same hostHash and pidHash. intraRank0 is the first such rank, intraRanks is the number of such ranks, and intraRank is the current rank's index among them.
// Find the ranks that live in the same process on this host (same hostHash
// and same pidHash):
//   intraRank0 - global rank of the first such rank
//   intraRank  - this rank's index among them
//   intraRanks - how many such ranks there are
int intraRank0 = -1, intraRank = -1, intraRanks = 0;
for (int i = 0; i < nranks; i++) {
if ((allGather1Data[i].peerInfo.hostHash == allGather1Data[rank].peerInfo.hostHash) &&
(allGather1Data[i].peerInfo.pidHash == allGather1Data[rank].peerInfo.pidHash)) {
if (intraRanks == 0) intraRank0 = i;
if (i == rank) intraRank = intraRanks;
intraRanks++;
}
}
// Pass the comm of the first intra-process rank: it owns the shared structures.
NCCLCHECK(ncclCommSetIntra(comm, intraRank, intraRanks, allGather1Data[intraRank0].comm));
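intraBarrier is used for CPU synchronization. Here we can see that intraBarrier, intraParams and the other shared structures all belong to intraRank0. Intra rank 0 allocates them, and every other rank waits until the pointers appear in comm0 and then reuses them.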
// Record this rank's position among the ranks that share its process and
// attach it to the structures shared by those ranks.
//   rank  : index of this rank among the intra-process ranks (0-based)
//   ranks : number of intra-process ranks
//   comm0 : comm of intra-process rank 0, which allocates the shared
//           structures; every other rank reuses comm0's pointers
// Also initializes myParams, selects the launch mode (GROUP or PARALLEL) and
// decides whether the single-call cooperative multi-device launch can be
// used (*intraCGMode).
ncclResult_t ncclCommSetIntra(struct ncclComm* comm, int rank, int ranks, struct ncclComm* comm0) {
comm->intraRank = rank;
comm->intraRanks = ranks;
comm->intraPhase = 0;
// Alloc shared structures
if (rank == 0) {
// Intra rank 0 allocates: a two-counter CPU barrier, one cudaLaunchParams
// and one device id per intra-process rank, the shared CG launch-mode flag
// (0x11 initially) and the compute capability used by the CG check below.
assert(comm == comm0);
int* bar;
NCCLCHECK(ncclCalloc(&bar, 2));
bar[0] = bar[1] = 0;
comm->intraBarrier = bar;
NCCLCHECK(ncclCalloc(&comm->intraParams, comm->intraRanks));
NCCLCHECK(ncclCalloc(&comm->intraCudaDevs, comm->intraRanks));
int* CGMode;
NCCLCHECK(ncclCalloc(&CGMode, 1));
*CGMode = 0x11;
comm->intraCGMode = CGMode;
int* CC;
NCCLCHECK(ncclCalloc(&CC, 1));
*CC = ncclCudaCompCap();
comm->intraCC = CC;
} else {
// Other ranks wait (presumably blocking in waitForNonNullPtr until rank 0
// has stored a non-NULL pointer in comm0) and then share the same objects.
comm->intraBarrier = (int*)waitForNonNullPtr(&comm0->intraBarrier);
comm->intraParams = (struct cudaLaunchParams*)waitForNonNullPtr(&comm0->intraParams);
comm->intraCudaDevs = (int*)waitForNonNullPtr(&comm0->intraCudaDevs);
comm->intraCGMode = (int*)waitForNonNullPtr(&comm0->intraCGMode);
comm->intraCC = (int*)waitForNonNullPtr(&comm0->intraCC);
}
// Publish this rank's device in the shared table.
comm->intraCudaDevs[comm->intraRank] = comm->cudaDev;
// myParams = this rank's slot in intraParams.
NCCLCHECK(initParams(comm));
int cgMdLaunch = 0;
// Set CG Mode
// GROUP is the default; a single intra-process rank, or
// NCCL_LAUNCH_MODE=PARALLEL, selects PARALLEL instead.
comm->launchMode = ncclComm::GROUP;
char* str = getenv("NCCL_LAUNCH_MODE");
if (str) INFO(NCCL_ENV, "NCCL_LAUNCH_MODE set by environment to %s", str);
if (comm->intraRanks == 1 || (str && strcmp(str, "PARALLEL") == 0)) {
comm->launchMode = ncclComm::PARALLEL;
}
if (comm->launchMode == ncclComm::GROUP) {
CUDACHECK(cudaStreamCreateWithFlags(&comm->groupStream, cudaStreamNonBlocking));
#if CUDART_VERSION >= 9000
// Only query multi-device cooperative launch support when this GPU's
// compute capability matches the one recorded by intra rank 0.
if (*comm->intraCC && (ncclCudaCompCap() == *comm->intraCC)) {
// Check whether the GPU supports Cooperative Group Multi Device Launch
(void) cudaDeviceGetAttribute(&cgMdLaunch, cudaDevAttrCooperativeMultiDeviceLaunch, comm->cudaDev);
}
#endif
}
// Disable cgMdLaunch if any rank does not support it
// (0x11 -> 0x10 clears bit 0, the bit that selects the single-call
// cudaLaunchCooperativeKernelMultiDevice path; the flag is shared, so one
// rank clearing it disables it for every intra-process rank).
if (cgMdLaunch == 0) {
*comm->intraCGMode = 0x10;
}
return ncclSuccess;
}
Then initParams points myParams at this rank's entry in intraParams and sets its args to &comm->argsptr, as shown in Figure 1.
// Point comm->myParams at this rank's slot in the shared intraParams array
// and reset it to an empty launch. The kernel argument array is
// &comm->argsptr. Grid and block x-dimensions start at 0 and grow as
// operations are saved; the y/z dimensions are 1.
ncclResult_t initParams(struct ncclComm* comm) {
  comm->myParams = &comm->intraParams[comm->intraRank];
  struct cudaLaunchParams* p = comm->myParams;
  p->gridDim.x = 0;
  p->gridDim.y = 1;
  p->gridDim.z = 1;
  p->blockDim.x = 0;
  p->blockDim.y = 1;
  p->blockDim.z = 1;
  p->sharedMem = 0;
  p->stream = NULL;
  p->args = &comm->argsptr;
  return ncclSuccess;
}
Then ncclSend begins execution, generating ncclInfo through user parameters.
// Queue a send of `count` elements of `datatype` from sendbuff to rank `peer`.
// The call is wrapped in its own ncclGroupStart/ncclGroupEnd pair. Inside a
// user group it only records the operation; outside a group the inner
// ncclGroupEnd is the outermost one and proceeds with the operation.
ncclResult_t ncclSend(const void* sendbuff, size_t count, ncclDataType_t datatype, int peer,
    ncclComm_t comm, cudaStream_t stream) {
  // recvbuff == NULL marks this ncclInfo as a send; the peer is carried in
  // the `root` position.
  struct ncclInfo info = { ncclCollSendRecv, "Send",
    sendbuff, NULL, count, datatype, ncclSum, peer, comm, stream, /* Args */
    1, 1 };
  NCCLCHECK(ncclGroupStart());
  ncclResult_t status = ncclEnqueueCheck(&info);
  NCCLCHECK(ncclGroupEnd());
  return status;
}
ncclGroupStart increments ncclGroupMode, and when entering the outermost group it also clears ncclGroupArgs. A non-zero ncclGroupMode means we are inside a group. Operations issued between ncclGroupStart and ncclGroupEnd don't block; they are all submitted at once by the outermost ncclGroupEnd.
// Enter a (possibly nested) group. Entering the outermost group resets the
// per-thread table of async arguments.
ncclResult_t ncclGroupStart() {
  const bool outermost = (ncclGroupMode == 0);
  if (outermost) memset(ncclGroupArgs, 0, MAX_ASYNC_OPS*sizeof(struct ncclAsyncArgs));
  ncclGroupMode += 1;
  return ncclSuccess;
}
Let’s look at ncclEnqueueCheck next
// Validate a user operation and, in group (async) mode, record it on its comm.
// p2p operations go to comm->p2plist via ncclSaveP2p; all others go through
// ncclSaveKernel. The result is also passed to ncclAsyncErrCheck.
// NOTE: excerpt - only the ncclAsyncMode() branch is shown; the remainder of
// the function is not part of this listing.
ncclResult_t ncclEnqueueCheck(struct ncclInfo* info) {
// Launch asynchronously if needed
if (ncclAsyncMode()) {
ncclResult_t ret = ncclSuccess;
int savedDev = -1;
// Check arguments
NCCLCHECK(PtrCheck(info->comm, info->opName, "comm"));
// With checkPointers, switch to the comm's device for the argument checks
// (presumably so user pointers can be validated there) and restore the
// caller's device at `end`.
if (info->comm->checkPointers) {
CUDACHECKGOTO(cudaGetDevice(&savedDev), ret, end);
CUDACHECKGOTO(cudaSetDevice(info->comm->cudaDev), ret, end);
}
NCCLCHECKGOTO(ArgsCheck(info), ret, end);
// Always register comm even in case of error to make sure ncclGroupEnd
// cleans it up.
NCCLCHECKGOTO(ncclAsyncColl(info->comm), ret, end);
// All operations of a group on this comm must use the same stream.
NCCLCHECKGOTO(checkSetStream(info), ret, end);
if (info->coll == ncclCollSendRecv) { //p2p stored separately
NCCLCHECKGOTO(ncclSaveP2p(info), ret, end);
} else {
NCCLCHECKGOTO(ncclSaveKernel(info), ret, end);
}
end:
if (savedDev != -1) CUDACHECK(cudaSetDevice(savedDev));
ncclAsyncErrCheck(ret);
return ret;
}
ncclGroupArgs and ncclGroupIndex are thread_local variables, indicating there are ncclGroupIndex AsyncArgs in total. Here it checks if there are AsyncArgs for the current comm in ncclGroupArgs. If not, a new one is added, setting funcType to ASYNC_FUNC_COLL and setting comm.
// Register `comm` in this thread's group argument table (at most once).
// A new entry gets funcType ASYNC_FUNC_COLL. Fails with ncclInvalidUsage
// (via ncclAsyncErrCheck) when the table already holds MAX_ASYNC_OPS entries.
ncclResult_t ncclAsyncColl(ncclComm_t comm) {
  for (int idx = 0; idx < ncclGroupIndex; idx++) {
    if (ncclGroupArgs[idx].coll.comm == comm) return ncclSuccess;  // already registered
  }
  if (ncclGroupIndex >= MAX_ASYNC_OPS) {
    WARN("Too many async operations in progress, max is %d", MAX_ASYNC_OPS);
    return ncclAsyncErrCheck(ncclInvalidUsage);
  }
  struct ncclAsyncArgs* slot = ncclGroupArgs + ncclGroupIndex;
  ncclGroupIndex++;
  slot->funcType = ASYNC_FUNC_COLL;
  slot->coll.comm = comm;
  return ncclSuccess;
}
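Then checkSetStream records info->stream as comm->userStream. The first operation in a group sets it. Any later operation in the same group that uses a different stream fails with ncclInvalidUsage.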
// Bind the comm to the stream of the first operation in the group, and
// reject later operations in the same group that use a different stream.
static ncclResult_t checkSetStream(struct ncclInfo* info) {
  if (!info->comm->userStreamSet) {
    // First operation: adopt its stream.
    info->comm->userStream = info->stream;
    info->comm->userStreamSet = true;
    return ncclSuccess;
  }
  if (info->stream == info->comm->userStream) return ncclSuccess;
  WARN("Error : mixing different streams within a group call is not supported.");
  return ncclInvalidUsage;
}
Then ncclSaveP2p saves the p2p information into comm->p2plist. peer is the rank to send to (or receive from). For a send, delta satisfies (rank + delta) % nranks = peer; for a receive, (rank - delta + nranks) % nranks = peer. The channels for this peer are derived from delta, and p2pnChannelsPerPeer channels transfer the data in parallel. If a channel has not yet established a connection with the peer, the connection request is recorded first. For example, for a send on channel id, peer is stored at send[id * nranks + nsend[id]] and nsend[id] is incremented, so the connection can be set up later. Finally, sendbuff and its length in bytes are recorded in the peer's entry of peerlist (recvbuff and its length for a receive), as shown in Figure 1.
// Record one ncclSend/ncclRecv in comm->p2plist.
// info->recvbuff == NULL means a send; info->root is the peer rank. For a
// remote peer, every one of its p2pnChannelsPerPeer channels that is not yet
// connected in that direction gets the peer queued in p2plist->connect. The
// buffer and its size in bytes are stored in peerlist[peer], overwriting any
// earlier entry for that peer and direction.
ncclResult_t ncclSaveP2p(struct ncclInfo* info) {
  struct ncclComm* comm = info->comm;
  struct ncclP2Plist* p2plist = &comm->p2plist;
  const int peer = info->root;
  const bool isSend = (info->recvbuff == NULL);
  p2plist->count++;
  ssize_t nBytes = info->count*ncclTypeSize(info->datatype);

  if (isSend) {
    if (peer != comm->rank) {
      // Sends go "forward": (rank + delta) % nRanks == peer.
      int delta = (comm->nRanks + peer - comm->rank) % comm->nRanks;
      for (int c = 0; c < comm->p2pnChannelsPerPeer; c++) {
        int channelId = (delta + comm->p2pChannels[c]) % comm->p2pnChannels;
        if (comm->channels[channelId].peers[peer].send.connected != 0) continue;
        int slot = p2plist->connect.nsend[channelId]++;
        p2plist->connect.send[channelId*comm->nRanks + slot] = peer;
      }
    }
    p2plist->peerlist[peer].sendbytes = nBytes;
    p2plist->peerlist[peer].sendbuff = info->sendbuff;
    return ncclSuccess;
  }

  if (peer != comm->rank) {
    // Receives go "backward": (rank - delta + nRanks) % nRanks == peer.
    int delta = (comm->nRanks + comm->rank - peer) % comm->nRanks;
    for (int c = 0; c < comm->p2pnChannelsPerPeer; c++) {
      int channelId = (delta + comm->p2pChannels[c]) % comm->p2pnChannels;
      if (comm->channels[channelId].peers[peer].recv.connected != 0) continue;
      int slot = p2plist->connect.nrecv[channelId]++;
      p2plist->connect.recv[channelId*comm->nRanks + slot] = peer;
    }
  }
  p2plist->peerlist[peer].recvbytes = nBytes;
  p2plist->peerlist[peer].recvbuff = info->recvbuff;
  return ncclSuccess;
}
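Then ncclGroupEnd runs and decrements ncclGroupMode. ncclSend is called inside the user's ncclGroupStart/ncclGroupEnd pair, so ncclGroupMode is still non-zero after the decrement. ncclGroupEnd therefore returns immediately, and ncclSend is done.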
// Leave one level of group nesting. A call that is still nested inside an
// enclosing group returns here without doing any work; only the outermost
// call continues (the rest of the function is elided in this excerpt).
ncclResult_t ncclGroupEnd() {
if (ncclGroupMode == 0) {
WARN("ncclGroupEnd: not in a group call.");
return ncclInvalidUsage;
}
ncclGroupMode--;
if (ncclGroupMode > 0) return ncclSuccess;
...
}
Next is the ncclRecv process, which is identical to ncclSend. After execution, recv-related information is also saved to p2plist.
// Queue a receive of `count` elements of `datatype` from rank `peer` into
// recvbuff. Like ncclSend, it is wrapped in its own group, so inside a user
// group it only records the operation.
ncclResult_t ncclRecv(void* recvbuff, size_t count, ncclDataType_t datatype, int peer,
    ncclComm_t comm, cudaStream_t stream) {
  // sendbuff == NULL and a non-NULL recvbuff mark this ncclInfo as a
  // receive; the peer is carried in the `root` position.
  struct ncclInfo info = { ncclCollSendRecv, "Recv",
    NULL, recvbuff, count, datatype, ncclSum, peer, comm, stream, /* Args */
    1, 1 };
  NCCLCHECK(ncclGroupStart());
  ncclResult_t status = ncclEnqueueCheck(&info);
  NCCLCHECK(ncclGroupEnd());
  return status;
}
Then ncclGroupEnd begins execution. Previously, ncclSend and ncclRecv wrote related information to p2plist. The first step now is to establish connections if they don’t exist.
// Excerpt of ncclGroupEnd: leave the group and, at the outermost level,
// establish any p2p connections the queued operations still need.
ncclResult_t ncclGroupEnd() {
  if (ncclGroupMode == 0) {
    WARN("ncclGroupEnd: not in a group call.");
    return ncclInvalidUsage;
  }
  // Only the outermost ncclGroupEnd does any work.
  ncclGroupMode--;
  if (ncclGroupMode > 0) return ncclSuccess;
  int savedDev;
  CUDACHECK(cudaGetDevice(&savedDev));
  int activeThreads = 0;
  int doneArray[MAX_ASYNC_OPS];
  for (int i=0; i<ncclGroupIndex; i++) doneArray[i] = 1;
  // An error recorded earlier in the group aborts the whole group.
  ncclResult_t ret = ncclGroupError;
  if (ret != ncclSuccess) goto group_cleanup;
  /* Launch async ncclCommInitRank */
  ...
  // For every comm with queued p2p operations, count the per-channel
  // connections still missing; if any, set them up on a helper thread.
  for (int i=0; i<ncclGroupIndex; i++) {
    struct ncclAsyncArgs* args = ncclGroupArgs+i;
    if (args->funcType == ASYNC_FUNC_COLL) {
      struct ncclP2Plist* p2plist = &args->coll.comm->p2plist;
      if (p2plist->count != 0) {
        struct ncclComm* comm = args->coll.comm;
        args->coll.connect = 0;
        for (int c=0; c<comm->p2pnChannels; c++)
          args->coll.connect += comm->p2plist.connect.nsend[c] + comm->p2plist.connect.nrecv[c];
        if (args->coll.connect) {
          pthread_create(ncclGroupThreads+i, NULL, ncclAsyncThreadPreconnect, args);
        }
      }
    }
  }
  // Wait for the preconnect threads and propagate their results.
  for (int i=0; i<ncclGroupIndex; i++) {
    struct ncclAsyncArgs* args = ncclGroupArgs+i;
    if (args->funcType == ASYNC_FUNC_COLL && (args->coll.connect)) {
      int err = pthread_join(ncclGroupThreads[i], NULL);
      if (err != 0) {
        // pthread_join reports failure through its return value and does
        // not set errno, so describe `err` rather than errno.
        WARN("Error waiting for pthread_join : %s\n", strerror(err));
        return ncclSystemError;
      }
      NCCLCHECKGOTO(args->ret, ret, end);
    }
  }
  ...
}
For each AsyncArgs that still has connections to establish, a thread is started to execute ncclAsyncThreadPreconnect. It calls ncclTransportP2pSetup for each p2p channel, using the nrecv/recv and nsend/send information recorded in p2plist, and then resets that channel's counters.
// Thread entry point: establish the pending p2p connections of one comm.
// For every p2p channel, connect to the peers queued in
// comm->p2plist.connect, then clear that channel's pending counts.
// Returns its argument; failures are reported through the *CHECKTHREAD
// macros.
void* ncclAsyncThreadPreconnect(void* args_) {
  // Keep the name `args`: the *CHECKTHREAD macros may refer to it.
  struct ncclAsyncArgs* args = (struct ncclAsyncArgs*)args_;
  struct ncclComm* comm = args->coll.comm;
  CUDACHECKTHREAD(cudaSetDevice(comm->cudaDev));
  struct ncclP2PConnect* connect = &comm->p2plist.connect;
  for (int ch = 0; ch < comm->p2pnChannels; ++ch) {
    struct ncclChannel* channel = comm->channels + ch;
    int* recvPeers = connect->recv + ch*comm->nRanks;
    int* sendPeers = connect->send + ch*comm->nRanks;
    NCCLCHECKTHREAD(ncclTransportP2pSetup(comm, NULL, channel, connect->nrecv[ch], recvPeers, connect->nsend[ch], sendPeers));
    connect->nrecv[ch] = 0;
    connect->nsend[ch] = 0;
  }
  return args;
}
Then all ncclSend and ncclRecv tasks are distributed to channels. For each AsyncArgs, NCCL iterates over every delta to get the rank to send to (to) and the rank to receive from (from). It then uses p2pnChannelsPerPeer channels to send and receive in parallel. Each channel handles about sendbytes / p2pnChannelsPerPeer bytes, rounded up to a multiple of stepSize, and larger transfers are split into several chunks that cycle through those channels. In the above example, rank0 (the first AsyncArgs) executes scheduleSendRecv twice: first with from = to = 0 (delta = 0), then with from = to = 1 (delta = 1).
// Excerpt of ncclGroupEnd: turn the p2p operations saved in each comm's
// p2plist into per-channel send/recv work via scheduleSendRecv.
ncclResult_t ncclGroupEnd() {
...
for (int i=0; i<ncclGroupIndex; i++) {
struct ncclAsyncArgs* args = ncclGroupArgs+i;
if (args->funcType == ASYNC_FUNC_COLL) {
struct ncclComm* comm = args->coll.comm;
int rank = comm->rank;
int nRanks = comm->nRanks;
struct ncclP2Plist* p2plist = &args->coll.comm->p2plist;
if (p2plist->count) {
// For each delta, pair the send to rank+delta with the receive from rank-delta.
for (int delta=0; delta<nRanks; delta++) {
uint32_t from = (rank+nRanks-delta)%nRanks;
uint32_t to = (rank+delta)%nRanks;
// Compute how much to split operations
// Natural step size matching buffer steps.
ssize_t stepSize = 4*comm->buffSizes[NCCL_PROTO_SIMPLE] / NCCL_STEPS;
// Split each operation on p2pnChannelsPerPeer max.
ssize_t recvChunkSize = DIVUP(p2plist->peerlist[from].recvbytes, comm->p2pnChannelsPerPeer);
ssize_t sendChunkSize = DIVUP(p2plist->peerlist[to].sendbytes, comm->p2pnChannelsPerPeer);
// Round chunk sizes up to a whole number of steps (at least one step).
recvChunkSize = std::max((ssize_t)1, DIVUP(recvChunkSize, stepSize)) * stepSize;
sendChunkSize = std::max((ssize_t)1, DIVUP(sendChunkSize, stepSize)) * stepSize;
ssize_t sendOffset = 0;
ssize_t recvOffset = 0;
int remaining = 1;
int chunk = 0;
// Emit one chunk pair per iteration, cycling through this peer's
// p2pnChannelsPerPeer channels, until neither direction has data left.
while (remaining) {
int channelId = (delta+comm->p2pChannels[chunk%comm->p2pnChannelsPerPeer]) % comm->p2pnChannels;
remaining = 0;
ssize_t recvbytes = p2plist->peerlist[from].recvbytes-recvOffset;
ssize_t sendbytes = p2plist->peerlist[to].sendbytes-sendOffset;
// A direction that fits in this chunk is finished: its peerlist size is
// set to -1, so later iterations compute a negative (empty) byte count.
if (recvbytes > recvChunkSize) { remaining = 1; recvbytes = recvChunkSize; } else p2plist->peerlist[from].recvbytes = -1;
if (sendbytes > sendChunkSize) { remaining = 1; sendbytes = sendChunkSize; } else p2plist->peerlist[to].sendbytes = -1;
// Schedule only if at least one direction has data (possibly 0 bytes).
if (sendbytes >= 0 || recvbytes >= 0) {
NCCLCHECKGOTO(scheduleSendRecv(comm, delta, channelId,
recvbytes, ((char*)(p2plist->peerlist[from].recvbuff)) + recvOffset,
sendbytes, ((const char*)(p2plist->peerlist[to].sendbuff)) + sendOffset), ret, end);
}
recvOffset += recvChunkSize;
sendOffset += sendChunkSize;
chunk++;
}
}
// All saved p2p operations of this comm are now scheduled.
p2plist->count = 0;
}
}
}
...
}
Then generate an ncclInfo, recording channelId, sendbuff, recvbuff and other information, and execute ncclSaveKernel.
// Package one send/recv chunk pair on channel `channelId` into an ncclInfo
// and hand it to ncclSaveKernel. Sizes are carried as ncclInt8 elements
// (1 element == 1 byte); the caller passes a negative byte count for a
// direction that has no data left. A self send/recv (delta == 0) must move
// equal amounts, otherwise ncclInvalidUsage is returned.
static ncclResult_t scheduleSendRecv(struct ncclComm* comm, int delta, int channelId, ssize_t recvbytes, void* recvbuff, ssize_t sendbytes, const void* sendbuff) {
  if (delta == 0 && sendbytes != recvbytes) return ncclInvalidUsage;
  // The element count is the larger of the two directions.
  const size_t nElems = (size_t)std::max<ssize_t>(sendbytes, recvbytes);
  struct ncclInfo info = { ncclCollSendRecv, "SendRecv",
    sendbuff, recvbuff, nElems, ncclInt8, ncclSum, -1, comm, comm->userStream, /* Args */
    1, 1 };
  info.delta = delta;
  info.channelId = channelId;
  info.sendbytes = sendbytes;
  info.recvbytes = recvbytes;
  NCCLCHECK(ncclSaveKernel(&info));
  return ncclSuccess;
}
Then ncclSaveKernel sets the kernel-related parameters, namely an ncclColl; the args in Figure 1 are of type ncclColl. As mentioned in section 7, each channel's collectives (an ncclColl array) are allocated during initChannel.
// One entry of a channel's operation FIFO (channel->collectives). The first
// operation of a launch is also passed to the kernel by value. The int
// data[0x10] member sizes the entry to at least 16 ints, which lets it be
// copied as ints (load_parallel).
struct ncclColl {
union {
struct {
struct CollectiveArgs args;
uint16_t funcIndex; // which kernel function runs this op (index into ncclFuncs / ncclKerns)
uint16_t nextIndex; // FIFO index of the next ncclColl
uint8_t active; // slot state: 0 = free, 1 = occupied, 2 = occupied and the last op of this launch
};
int data[0x10];
};
};
// Arguments of one operation as seen by the device. The union holds the
// collective-specific (coll) or send/recv-specific (p2p) fields; nThreads
// comes first in every variant so common.nThreads can be read for any op.
struct CollectiveArgs {
struct ncclDevComm* comm;
// local and remote input, output, and buffer
const void * sendbuff;
void * recvbuff;
// Op-specific fields. Make sure the common part stays the
// same on all structs of the union
union {
struct {
uint16_t nThreads;
} common;
struct {
uint16_t nThreads;
uint8_t bid;
uint8_t nChannels;
uint32_t root;
size_t count;
size_t lastChunkSize;
} coll;
struct {
uint16_t nThreads;
uint16_t unused;
int32_t delta; // peer offset; negative marks a no-op entry
size_t sendCount;
size_t recvCount;
} p2p;
};
};
computeColl initializes the ncclColl coll from ncclInfo (sendbuff, recvbuff, comm, etc.), and ncclSaveKernel then updates myParams’ blockDim and gridDim. It finds the channel using info’s channelId and tries to append coll to that channel’s collectives. collFifoTail is the tail of collectives, and c is the ncclColl at that position. It first waits until c’s active is 0 (the slot is free). Then it copies coll into c, sets active to 1, sets c’s nextIndex to the following slot, advances collFifoTail to that slot, and increments the channel’s collCount. Note that ncclProxySaveP2p has no effect in the current scenario, so it is omitted.
// Convert an ncclInfo into an ncclColl and append it to the operation FIFO
// of every channel it runs on, growing this rank's launch dimensions
// (myParams) to cover it. For a single-rank comm, non-p2p operations are
// done as a device-to-device cudaMemcpyAsync on info->stream instead.
// Returns ncclInvalidUsage if a channel already holds NCCL_MAX_OPS ops.
ncclResult_t ncclSaveKernel(struct ncclInfo* info) {
if (info->comm->nRanks == 1 && info->coll != ncclCollSendRecv) {
if (info->sendbuff != info->recvbuff)
CUDACHECK(cudaMemcpyAsync(info->recvbuff, info->sendbuff, info->nBytes, cudaMemcpyDeviceToDevice, info->stream));
return ncclSuccess;
}
struct ncclColl coll;
struct ncclProxyArgs proxyArgs;
memset(&proxyArgs, 0, sizeof(struct ncclProxyArgs));
NCCLCHECK(computeColl(info, &coll, &proxyArgs));
// Block size = the largest nThreads of any op saved so far.
info->comm->myParams->blockDim.x = std::max<unsigned>(info->comm->myParams->blockDim.x, info->nThreads);
// A p2p op goes to exactly one channel (info->channelId).
int nChannels = info->coll == ncclCollSendRecv ? 1 : coll.args.coll.nChannels;
// CollNet tree patterns use two sub-channels (up and down).
int nSubChannels = (info->pattern == ncclPatternCollTreeUp || info->pattern == ncclPatternCollTreeDown) ? 2 : 1;
for (int bid=0; bid<nChannels*nSubChannels; bid++) {
// Collectives take channels round-robin from the current gridDim.x.
int channelId = (info->coll == ncclCollSendRecv) ? info->channelId :
info->comm->myParams->gridDim.x % info->comm->nChannels;
struct ncclChannel* channel = info->comm->channels+channelId;
if (channel->collCount == NCCL_MAX_OPS) {
WARN("Too many aggregated operations on channel %d (%d max)", channel->id, NCCL_MAX_OPS);
return ncclInvalidUsage;
}
// Proxy
proxyArgs.channel = channel;
// Adjust pattern for CollNet based on channel index
if (nSubChannels == 2) {
info->pattern = (channelId < info->comm->nChannels/nSubChannels) ? ncclPatternCollTreeUp : ncclPatternCollTreeDown;
}
if (info->coll == ncclCollSendRecv) {
// Make sure the grid covers this channel's block.
info->comm->myParams->gridDim.x = std::max<unsigned>(info->comm->myParams->gridDim.x, channelId+1);
NCCLCHECK(ncclProxySaveP2p(info, channel));
} else {
NCCLCHECK(ncclProxySaveColl(&proxyArgs, info->pattern, info->root, info->comm->nRanks));
}
// Incremented for every op; for p2p, setupLaunch later recomputes gridDim.x
// from the channels' collCount.
info->comm->myParams->gridDim.x++;
int opIndex = channel->collFifoTail;
struct ncclColl* c = channel->collectives+opIndex;
// Wait until the slot at the FIFO tail is free (active reset to 0 once the
// previous occupant has been loaded).
volatile uint8_t* activePtr = (volatile uint8_t*)&c->active;
while (activePtr[0] != 0) sched_yield();
memcpy(c, &coll, sizeof(struct ncclColl));
if (info->coll != ncclCollSendRecv) c->args.coll.bid = bid % coll.args.coll.nChannels;
// Mark the slot occupied, link it to the next slot and advance the tail.
c->active = 1;
opIndex = (opIndex+1)%NCCL_MAX_OPS;
c->nextIndex = opIndex;
channel->collFifoTail = opIndex;
channel->collCount++;
}
info->comm->opCount++;
return ncclSuccess;
}
At this point scheduleSendRecv execution is complete. Back to ncclGroupEnd, it will execute ncclBarrierEnqueue for each AsyncArgs
// Excerpt of ncclGroupEnd: once all operations are scheduled, call
// ncclBarrierEnqueue for every comm registered in the group.
ncclResult_t ncclGroupEnd() {
...
for (int i=0; i<ncclGroupIndex; i++) {
struct ncclAsyncArgs* args = ncclGroupArgs+i;
if (args->funcType == ASYNC_FUNC_COLL) {
// The comm's device is made current only when no user stream was set
// (userStream == NULL). NOTE(review): presumably a non-NULL stream already
// carries the right device - confirm.
if (args->coll.comm->userStream == NULL)
CUDACHECKGOTO(cudaSetDevice(args->coll.comm->cudaDev), ret, end);
NCCLCHECKGOTO(ncclBarrierEnqueue(args->coll.comm), ret, end);
}
}
...
}
First, myParams is set through setupLaunch.
// Excerpt of ncclBarrierEnqueue: nothing to launch if no saved operation grew
// the grid (gridDim.x == 0); otherwise finalize this rank's launch
// parameters with setupLaunch.
ncclResult_t ncclBarrierEnqueue(struct ncclComm* comm) {
struct cudaLaunchParams* params = comm->myParams;
if (params->gridDim.x == 0) return ncclSuccess;
NCCLCHECK(setupLaunch(comm, params));
...
return ncclSuccess;
}
We mentioned during channel search that one channel corresponds to one block. In setupLaunch, NCCL iterates over the p2p channels and sets gridDim.x to the index of the last channel with work, plus one. Some channels in that range may have no p2p operations, so a fake ncclColl is created for each empty channel. Its delta is set to -1 to indicate a no-op, and funcIndex, comm and other fields are filled in. Then the last ncclColl of each channel gets active = 2, marking it as the last one. Finally, the first ncclColl of channel 0 is copied to comm->args, its slot is released by setting active to 0, and func in myParams is set. This completes the kernel parameter setup.
// Finalize this rank's launch parameters just before the kernel launch:
//  - when any p2p channel has queued work, gridDim.x becomes the index of the
//    last such channel plus one (one block per channel);
//  - each channel in [0, gridDim.x) without work gets a no-op ncclColl
//    (p2p.delta = -1) so its block has something to read;
//  - the last queued ncclColl of every launched channel is marked
//    active = 2, which makes the kernel stop after it;
//  - channel 0's first ncclColl is copied into comm->args (the by-value
//    kernel argument), its FIFO slot is released, and params->func is taken
//    from ncclKerns[funcIndex].
ncclResult_t setupLaunch(struct ncclComm* comm, struct cudaLaunchParams* params) {
// Only launch blocks where we have work to do.
for (int c=0; c<comm->p2pnChannels; c++) {
if (comm->channels[c].collCount) params->gridDim.x = c+1;
}
// Set active = 2 for the last operation and add a no-op on empty channels (p2p case).
for (int c=0; c<params->gridDim.x; c++) {
struct ncclChannel* channel = comm->channels+c;
if (channel->collCount == 0) {
int opIndex = channel->collFifoTail;
// NOTE: this `c` (the no-op slot) shadows the loop counter `c` inside this block.
struct ncclColl* c = channel->collectives+opIndex;
// Wait for the tail slot to be free before reusing it.
volatile uint8_t* activePtr = (volatile uint8_t*)&c->active;
while (activePtr[0] != 0) sched_yield();
c->args.p2p.delta = -1; // no-op
c->funcIndex = FUNC_INDEX_P2P;
c->args.comm = comm->devComm;
c->active = 1;
opIndex = (opIndex+1)%NCCL_MAX_OPS;
c->nextIndex = opIndex;
channel->collFifoTail = opIndex;
channel->collCount++;
}
channel->collectives[(channel->collStart+channel->collCount-1)%NCCL_MAX_OPS].active = 2;
}
// Find the first operation, choose the kernel accordingly and pass it
// as the first argument.
struct ncclColl* coll = comm->channels[0].collectives+comm->channels[0].collStart;
memcpy(&comm->args, coll, sizeof(struct ncclColl));
// As we pass that coll directly, we can free it immediately.
coll->active = 0;
params->func = ncclKerns[coll->funcIndex];
return ncclSuccess;
}
Then back to ncclBarrierEnqueue, ncclCpuBarrierIn will be executed.
// Excerpt of ncclBarrierEnqueue (continued): in GROUP launch mode every
// intra-process rank enters the CPU barrier. The one that arrives last
// launches the kernels of all intra-process ranks (comm->intraParams) in
// one go and then calls ncclCpuBarrierLast.
ncclResult_t ncclBarrierEnqueue(struct ncclComm* comm) {
...
if (comm->launchMode == ncclComm::GROUP) {
int isLast = 0;
NCCLCHECK(ncclCpuBarrierIn(comm, &isLast));
if (isLast) {
// I'm the last. Launch all operations.
NCCLCHECK(ncclLaunchCooperativeKernelMultiDevice(comm->intraParams, comm->intraCudaDevs, comm->intraRanks, *comm->intraCGMode));
NCCLCHECK(ncclCpuBarrierLast(comm));
}
}
return ncclSuccess;
}
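Here each call increments the intraBarrier counter with a compare-and-swap (CAS). The intraRanks-th call of ncclBarrierEnqueue sees that it is the last one and sets isLast to 1 instead. In other words, the kernels are launched only when the last AsyncArgs is processed, and that call launches them for all intra-process ranks at once.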
// Enter the CPU barrier shared by the intra-process ranks. intraBarrier holds
// one counter per phase, selected by comm->intraPhase. Each caller but the
// last increments the current phase's counter with compare-and-swap and gets
// *isLast = 0. The caller that would bring it to intraRanks does not
// increment; it clears the other phase's counter and gets *isLast = 1.
// Returns ncclInvalidUsage if the counter has already reached intraRanks.
ncclResult_t ncclCpuBarrierIn(struct ncclComm* comm, int* isLast) {
volatile int* ptr = (volatile int*)(comm->intraBarrier+comm->intraPhase);
int val = *ptr;
bool done = false;
while (done == false) {
if (val >= comm->intraRanks) {
WARN("Trying to launch too many collectives");
return ncclInvalidUsage;
}
if (val+1 == comm->intraRanks) {
// Reset the barrier.
comm->intraBarrier[comm->intraPhase^1] = 0;
*isLast = 1;
return ncclSuccess;
}
done = __sync_bool_compare_and_swap(ptr, val, val+1);
// On failure another rank incremented first; this function only ever
// increments the counter, so retry against val+1. (On success the loop exits.)
val++;
}
*isLast = 0;
return ncclSuccess;
}
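Then the kernels are launched on all devices. If every intra-process rank supports cooperative multi-device launch (bit 0 of cgMode is set), a single cudaLaunchCooperativeKernelMultiDevice call launches them all at once. Otherwise NCCL loops over the devices and calls cudaLaunchKernel on each.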
// Launch one kernel per device described by paramsList[0..numDevices).
// If bit 0 of cgMode is set (every intra-process rank supports cooperative
// multi-device launch), a single cudaLaunchCooperativeKernelMultiDevice call
// is used (CUDA >= 9.0). Otherwise each kernel is launched with
// cudaLaunchKernel on device cudaDevs[i]. The caller's current device is
// restored afterwards, including when a launch fails.
// NOTE(review): cudaLaunchCooperativeKernelMultiDevice is deprecated in newer
// CUDA releases - confirm against the targeted toolkit.
ncclResult_t ncclLaunchCooperativeKernelMultiDevice(struct cudaLaunchParams *paramsList, int* cudaDevs, int numDevices, int cgMode) {
#if CUDART_VERSION >= 9000
  if (cgMode & 0x01) {
    CUDACHECK(cudaLaunchCooperativeKernelMultiDevice(paramsList, numDevices,
            // These flags are to reduce the latency of using this API
            cudaCooperativeLaunchMultiDeviceNoPreSync|cudaCooperativeLaunchMultiDeviceNoPostSync));
    return ncclSuccess;
  }
#endif
  int savedDev;
  CUDACHECK(cudaGetDevice(&savedDev));
  for (int i = 0; i < numDevices; i++) {
    struct cudaLaunchParams* params = paramsList+i;
    cudaError_t err = cudaSetDevice(cudaDevs[i]);
    if (err == cudaSuccess) {
      err = cudaLaunchKernel(params->func, params->gridDim, params->blockDim, params->args, params->sharedMem, params->stream);
    }
    if (err != cudaSuccess) {
      // Put the caller's device back before reporting, so a failure on
      // device i does not leave device i current for the caller.
      (void) cudaSetDevice(savedDev);
      CUDACHECK(err);
    }
  }
  CUDACHECK(cudaSetDevice(savedDev));
  return ncclSuccess;
}
Kernel Execution
ncclKerns is defined as follows, we use the first one, which is ncclSendRecvKernel_copy_i8
// Builds kernel names of the form <coll>Kernel_<op>_<dtype>,
// e.g. NCCL_KERN_NAME(ncclSendRecv, copy, i8) -> ncclSendRecvKernel_copy_i8.
#define NCCL_KERN_NAME(coll, op, dtype) \
coll##Kernel_##op##_##dtype
// Kernel entry points, indexed by ncclColl::funcIndex (setupLaunch picks
// ncclKerns[coll->funcIndex]). Entry 0 is the send/recv copy kernel for
// int8 (presumably the FUNC_INDEX_P2P slot - confirm); the NCCL_FUNCS2A/2B
// macros expand to the variants of each collective.
static void* const ncclKerns[1+NCCL_NUM_FUNCTIONS*ncclNumOps*ncclNumTypes*NCCL_NUM_ALGORITHMS*NCCL_NUM_PROTOCOLS] = {
(void*)NCCL_KERN_NAME(ncclSendRecv, copy, i8),
NCCL_FUNCS2B(ncclBroadcast),
NCCL_FUNCS2A(ncclReduce),
NCCL_FUNCS2B(ncclAllGather),
NCCL_FUNCS2A(ncclReduceScatter),
NCCL_FUNCS2A(ncclAllReduce)
};
The first ncclColl is passed to the kernel as a parameter, so block 0 can point c directly at firstColl. The other blocks use load_coll to copy their ncclColl from the channel's collectives. After loading, the host-side ncclColl's active is set to 0 so that the slot can be reused.
// Copy `size` bytes from src to dst as 32-bit words, split across the
// block's threads: thread `tid` copies words tid, tid + blockDim.x, ...
// Trailing bytes beyond the last whole int are not copied.
static __device__ void load_parallel(void* dst, void* src, size_t size, int tid) {
  int* out = (int*)dst;
  int* in = (int*)src;
  const size_t nWords = size / sizeof(int);
  int w = tid;
  while (w < nWords) {
    out[w] = in[w];
    w += blockDim.x;
  }
}
// Copy a queued ncclColl (hostColl) into the block's local copy (localColl)
// using all threads of the block, then release the FIFO slot.
// Thread 0 reads the abort flag first; exitIfAbortBarrier presumably makes
// every thread exit together when it is set (TODO confirm).
static __device__ void load_coll(struct ncclColl* localColl, struct ncclColl* hostColl, int tid, struct ncclDevComm* comm) {
// Check whether the last operation was aborted and make sure all threads exit
int abort = tid == 0 ? *(comm->abortFlag) : 0;
exitIfAbortBarrier(abort);
load_parallel(localColl, hostColl, sizeof(struct ncclColl), tid);
// Every thread's part of the copy must be complete before the slot is freed.
__syncthreads();
// active = 0 lets the host reuse this slot (it spins while active != 0).
if (tid == 0) hostColl->active = 0;
}
Then begins the while loop to iterate through each ncclColl until ncclColl’s active becomes 2, indicating it’s the last one, at which point the loop exits.
// IMPL_COLL_KERN defines the __global__ entry point NCCL_KERN_NAME(coll, op, dtype).
// Block `bid` serves channel `bid`. Block 0 starts from firstColl, which is
// passed by value as the kernel argument. Other blocks load their first
// ncclColl from the channel FIFO at collFifoHead. Each block then runs its
// ops in FIFO order (following nextIndex) and returns after executing the
// one whose active == 2. Only threads with tid < nThreads run an op. An op
// whose funcIndex equals fIndex is run inline through coll##Kernel; any
// other op is dispatched through the ncclFuncs table. Thread 0 advances
// channel->collFifoHead as ops complete.
#define IMPL_COLL_KERN(coll, op, ncclFunc, dtype, ctype, fIndex) \
__global__ void NCCL_KERN_NAME(coll, op, dtype)(struct ncclColl firstColl) { \
int tid = threadIdx.x; \
int bid = blockIdx.x; \
__shared__ volatile uint64_t shmem[NCCL_LL128_SHMEM_SIZE]; \
ncclShmem = shmem; \
__shared__ struct ncclColl localColl; \
\
struct ncclDevComm* comm = firstColl.args.comm; \
struct ncclChannel* channel = comm->channels+bid; \
struct ncclColl* c; \
if (bid == 0) { \
/* To optimize for latency, (only) the first operation is passed as argument.*/ \
c = &firstColl; \
} else { \
c = &localColl; \
load_coll(c, channel->collectives+channel->collFifoHead, tid, comm); \
} \
while (1) { \
if (tid < c->args.common.nThreads) { \
if (c->funcIndex == fIndex) { \
coll##Kernel<COLL_UNROLL, ncclFunc<ctype>, ctype>(&c->args); \
} else { \
ncclFuncs[c->funcIndex](&c->args); \
} \
} \
int nextIndex = c->nextIndex; \
if (tid == 0) channel->collFifoHead = nextIndex; \
\
if (c->active == 2) { \
return; \
} \
\
/* Load next collective operation*/ \
c = &localColl; /* for bid 0 */ \
load_coll(c, channel->collectives+nextIndex, tid, comm); \
} \
}
For each ncclColl, ncclSendRecvKernel<4, FuncSum<int8_t>, int8_t> is executed. Let’s first look at thread organization within a block. Assuming args->p2p.nThreads is 320, 160 threads are used for send and 160 for recv. Further, of the 160 threads, 128 are used for actual data transfer, and the remaining 32 threads (one warp) are used for synchronization.
First, nthreads is computed as args->p2p.nThreads minus two warps, which is 320 - 64 = 256 here. sendbuff and recvbuff are taken from args. A negative delta means this channel has no p2p operation (it is a placeholder), so the kernel returns immediately. A delta of 0 means the rank is sending to and receiving from itself on the same GPU. In that case the data is copied directly with ReduceOrCopyMulti in chunks of blockSize elements (at most 1<<30, because ReduceOrCopyMulti takes an int element count).
// Per-block body of the send/recv kernel (the inter-rank part is elided below).
// nthreads is the number of worker threads: two warps of args->p2p.nThreads are excluded.
// The rest of this kernel uses them as one sync warp each for the send and receive halves.
// A negative delta means there is nothing to do on this channel. A delta of 0 is a copy
// from sendbuff to recvbuff on the same rank.
template<int UNROLL, class FUNC, typename T>
__device__ void ncclSendRecvKernel(struct CollectiveArgs* args) {
const int tid = threadIdx.x;
const int nthreads = args->p2p.nThreads-2*WARP_SIZE;
// Compute pointers
const T* sendbuff = (const T*)args->sendbuff;
T* recvbuff = (T*)args->recvbuff;
if (args->p2p.delta < 0 ) return; // No-op
if (args->p2p.delta == 0) {
// Only the worker threads copy. An in-place operation (same buffer) needs no copy.
if (tid < nthreads && sendbuff != recvbuff) {
// local copy : ReduceOrCopyMulti takes an int as number of elements,
// so we split it in blocks of 1G elements.
int blockSize = 1<<30;
for (size_t offset=0; offset<args->p2p.sendCount; offset += blockSize) {
size_t remaining = args->p2p.sendCount - offset;
// Shrink the final chunk to the elements that are left.
if (remaining < blockSize) blockSize = remaining;
ReduceOrCopyMulti<UNROLL, FUNC, T, 1, 1, 1, 1>(tid, nthreads, 1, &sendbuff, 1, &recvbuff, blockSize);
sendbuff += blockSize; recvbuff += blockSize;
}
}
return;
}
...
}
Then let’s look at ReduceOrCopyMulti, which handles actual data copying, reducing nsrcs source arrays through FUNC and copying to ndsts destination arrays, each array having length N. ReduceOrCopyMulti attempts to use 128-bit vectorized load/store to improve bandwidth utilization and reduce instruction count for better performance, but this requires aligned data (16 bytes). If src and dst aren’t 16-byte aligned but have the same remainder when divided by 16, then non-vectorized instructions can be used to copy the unaligned front portion, after which vectorized instructions can be used. If the remainders are different, only non-vectorized instructions can be used. The process has three steps: handle the unaligned front portion, handle the aligned middle portion, and handle the tail portion.
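ptrAlign128 returns a pointer’s address modulo 16, i.e., its byte offset from the previous 16-byte boundary. ReduceOrCopyMulti first XORs the offset of every src and dst with that of srcs[0] to check whether all of them are equally aligned. If they are not, Npreamble = N and every element must be copied with non-vectorized instructions; the same happens when N is smaller than 16. Otherwise, Npreamble = (alignof(Pack128) - align) % alignof(Pack128), which is the length of the unaligned front portion. Note that align is measured in bytes while Npreamble counts elements; the two coincide here because the send/recv kernel operates on int8_t.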
// 128-bit pack used for vectorized loads/stores (two 64-bit halves, .x and .y).
using Pack128 = ulong2;
// Byte offset of ptr from the previous alignof(Pack128) (16-byte) boundary, in [0, 16).
template <typename T>
__device__ int ptrAlign128(T* ptr) {
  const uint64_t address = reinterpret_cast<uint64_t>(ptr);
  return static_cast<int>(address % alignof(Pack128));
}
// Reduce nsrcs source arrays element-wise with FUNC and store the result into each of the
// ndsts destination arrays; every array holds N elements of T.
// Stage 1 (shown here) processes the leading elements with scalar accesses, so that the
// remaining elements start on a 16-byte boundary in every array and can use 128-bit
// accesses. The elided remainder covers the vectorized stages and the scalar tail.
template<int UNROLL, class FUNC, typename T, int MINSRCS, int MAXSRCS, int MINDSTS, int MAXDSTS>
__device__ __forceinline__ void ReduceOrCopyMulti(const int tid, const int nthreads,
    int nsrcs, const T* srcs[MAXSRCS], int ndsts, T* dsts[MAXDSTS],
    int N) {
  int Nrem = N;
  if (Nrem <= 0) return;

  // 128-bit accesses are only possible if every pointer has the same byte offset
  // modulo 16. XOR-ing each offset with the first one exposes any mismatch.
  int alignDiff = 0;
  int align = ptrAlign128(srcs[0]);
  #pragma unroll
  for (int i=1; i<MINSRCS; i++) alignDiff |= (align ^ ptrAlign128(srcs[i]));
  for (int i=MINSRCS; i<MAXSRCS && i<nsrcs; i++) alignDiff |= (align ^ ptrAlign128(srcs[i]));
  #pragma unroll
  for (int i=0; i<MINDSTS; i++) alignDiff |= (align ^ ptrAlign128(dsts[i]));
  for (int i=MINDSTS; i<MAXDSTS && i<ndsts; i++) alignDiff |= (align ^ ptrAlign128(dsts[i]));

  // Copy everything with scalar accesses when the offsets differ or N is small.
  // Otherwise copy just enough elements to reach the next 16-byte boundary.
  // `align` is a byte offset, so the byte distance is divided by sizeof(T) to get an
  // element count. Without the division, a 4-byte T at offset 4 would get a 12-element
  // (48-byte) preamble and remain misaligned for the 128-bit stage. For 1-byte types
  // such as int8_t the result is unchanged.
  int Npreamble = alignDiff ? Nrem :
    N < alignof(Pack128) ? N :
    (alignof(Pack128) - align) % alignof(Pack128) / sizeof(T);

  // stage 1: preamble: handle any elements up to the point of everything coming
  // into alignment
  if (Npreamble) {
    ReduceCopyMulti<FUNC, T, MINSRCS, MAXSRCS, MINDSTS, MAXDSTS>(tid, nthreads, nsrcs, srcs, ndsts, dsts, 0, Npreamble);
    Nrem -= Npreamble;
    if (Nrem == 0) return;
  }
  ...
}
For the unaligned portion, ReduceCopyMulti is used directly with non-vectorized instructions. Each of the 128 threads reads one int8_t from a run of 128 consecutive elements of src and stores it to dst. The threads then move on to the next 128 elements and repeat until the portion is done. The access pattern is shown in the following figure.
Figure 2
// Scalar load through a volatile pointer, so every call performs a real memory read.
template<typename T> inline __device__
T vFetch(const volatile T* ptr) {
  T result = ptr[0];
  return result;
}
// Scalar store of val through a volatile pointer, so every call performs a real memory write.
template<typename T> inline __device__
void vStore(volatile T* ptr, const T val) {
  ptr[0] = val;
}
// Scalar (non-vectorized) reduce-and-copy of elements [offset, offset + N).
// Thread tid handles elements offset + tid, offset + tid + nthreads, ...
// For each element, the nsrcs sources are folded with FUNC and the result is written to
// every one of the ndsts destinations. The first MINSRCS/MINDSTS arrays are known at
// compile time and fully unrolled; the remainder, up to nsrcs/ndsts, is bounded at run time.
template<class FUNC, typename T, int MINSRCS, int MAXSRCS, int MINDSTS, int MAXDSTS>
__device__ __forceinline__ void ReduceCopyMulti(const int tid, const int nthreads,
    int nsrcs, const T* srcs[MAXSRCS], int ndsts, T* dsts[MAXDSTS],
    const int offset, const int N) {
  const int end = offset + N;
  for (int e = offset + tid; e < end; e += nthreads) {
    // Fold the sources.
    T acc = vFetch(srcs[0] + e);
    #pragma unroll
    for (int s = 1; s < MINSRCS; s++) acc = FUNC()(acc, vFetch(srcs[s] + e));
    #pragma unroll 1
    for (int s = MINSRCS; s < MAXSRCS && s < nsrcs; s++) acc = FUNC()(acc, vFetch(srcs[s] + e));
    // Broadcast the result to the destinations.
    #pragma unroll
    for (int d = 0; d < MINDSTS; d++) vStore(dsts[d] + e, acc);
    #pragma unroll 1
    for (int d = MINDSTS; d < MAXDSTS && d < ndsts; d++) vStore(dsts[d] + e, acc);
  }
}
Step two then handles the aligned data, in two parts. First, ReduceCopy128bMulti runs with unrolling factor AUTOUNROLL on the largest prefix whose length is a multiple of packFactor * AUTOUNROLL * WARP_SIZE elements. Then ReduceCopy128bMulti runs again, with the unrolling factor set to 1, on the whole 128-bit packs that remain.
Finally, for data less than packFactor (can’t form 128 bits), use ReduceCopyMulti for non-vectorized copying.
// Continuation of ReduceOrCopyMulti after the stage-1 preamble. Stage 1 returns early
// unless every array shares the same 16-byte alignment, so `offset` is intended to be the
// first element that starts a 16-byte boundary in all arrays.
// Stage 2a: 128-bit copies with AUTOUNROLL unrolling over full warp iterations.
// Stage 2b: 128-bit copies without unrolling for the remaining whole packs.
// Stage 2c: scalar copy of the last elements that do not fill a 128-bit pack.
template<int UNROLL, class FUNC, typename T, int MINSRCS, int MAXSRCS, int MINDSTS, int MAXDSTS>
__device__ __forceinline__ void ReduceOrCopyMulti(const int tid, const int nthreads,
int nsrcs, const T* srcs[MAXSRCS], int ndsts, T* dsts[MAXDSTS],
int N) {
...
int offset = Npreamble;
// stage 2: fast path: use 128b loads/stores to do the bulk of the work,
// assuming the pointers we have are all 128-bit alignable.
int w = tid / WARP_SIZE; // Warp number
int nw = nthreads / WARP_SIZE; // Number of warps
int t = tid % WARP_SIZE; // Thread (inside the warp)
const int packFactor = sizeof(Pack128) / sizeof(T); // elements of T per 128-bit pack
// stage 2a: main loop
// Whole packs, rounded down to a multiple of AUTOUNROLL*WARP_SIZE so that every unrolled
// warp iteration of ReduceCopy128bMulti is full.
// NOTE(review): AUTOUNROLL is defined elsewhere; confirm how it relates to UNROLL.
int Npack2a = (Nrem / (packFactor * AUTOUNROLL * WARP_SIZE))
* (AUTOUNROLL * WARP_SIZE); // round down
int Nelem2a = Npack2a * packFactor;
ReduceCopy128bMulti<FUNC, T, AUTOUNROLL, MINSRCS, MAXSRCS, MINDSTS, MAXDSTS>(w, nw, t, nsrcs, srcs, ndsts, dsts, offset, Npack2a);
Nrem -= Nelem2a;
if (Nrem == 0) return;
offset += Nelem2a;
// stage 2b: slightly less optimized for section when we don't have full
// unrolling
// All remaining whole packs; UNROLL = 1, so no multiple-of-warp requirement.
int Npack2b = Nrem / packFactor;
int Nelem2b = Npack2b * packFactor;
ReduceCopy128bMulti<FUNC, T, 1, MINSRCS, MAXSRCS, MINDSTS, MAXDSTS>(w, nw, t, nsrcs, srcs, ndsts, dsts, offset, Npack2b);
Nrem -= Nelem2b;
if (Nrem == 0) return;
offset += Nelem2b;
// stage 2c: tail
// Fewer than packFactor elements remain; copy them with scalar accesses.
ReduceCopyMulti<FUNC, T, MINSRCS, MAXSRCS, MINDSTS, MAXDSTS>(tid, nthreads, nsrcs, srcs, ndsts, dsts, offset, Nrem);
}
Then let’s look at how ReduceCopy128bMulti uses vectorized instructions for copying. The load/store here uses inline PTX, though it seems unnecessary. Fetch128 loads a ulong2 from position p into register variable v. There’s a variable UNROLL here - one warp processes UNROLL * WARP_SIZE consecutive ulong2s at once, similar to loop unrolling. When UNROLL is 4, the memory access pattern is shown in the following figure. For example, thread 0 will read the first ulong2 from 4 yellow boxes into register variable vals, then write to dst.
Figure 3
Specifically, when UNROLL is 1, the access pattern is similar to ReduceCopyMulti - 128 threads process 128 consecutive ulong2s, then loop to process the next 128 ulong2s.
// 128-bit volatile load from global memory: v.x and v.y receive the two 64-bit halves at p.
// p must be 16-byte aligned, because PTX vector accesses require natural alignment.
// The "memory" clobber stops the compiler from moving other memory accesses across the load.
inline __device__ void Fetch128(Pack128& v, const Pack128* p) {
asm volatile("ld.volatile.global.v2.u64 {%0,%1}, [%2];" : "=l"(v.x), "=l"(v.y) : "l"(p) : "memory");
}
// 128-bit volatile store of v to global memory at p. Same alignment requirement as Fetch128.
inline __device__ void Store128(Pack128* p, Pack128& v) {
asm volatile("st.volatile.global.v2.u64 [%0], {%1,%2};" :: "l"(p), "l"(v.x), "l"(v.y) : "memory");
}
// Combine two 128-bit packs in place: x = FUNC-combination of x and y, applied separately
// to each 64-bit half through MULTI<FUNC, T>.
template<class FUNC, typename T>
struct MULTI128 {
  __device__ void operator()(Pack128& x, Pack128& y) {
    MULTI<FUNC, T> combine;
    x.x = combine(x.x, y.x);
    x.y = combine(x.y, y.y);
  }
};
// Vectorized (128-bit) reduce-and-copy of Npack Pack128 units, starting at element
// elemOffset of every source and destination array.
// w/nw are the warp index and warp count, and t is the lane.
// In each iteration warp w covers UNROLL*WARP_SIZE consecutive packs: lane t touches packs
// offset + u*WARP_SIZE for u in [0, UNROLL). Then all warps advance together by
// inc = nw*UNROLL*WARP_SIZE.
// Preconditions:
//  - Every s[i]+elemOffset and d[i]+elemOffset must be 16-byte aligned (Fetch128/Store128).
//  - With UNROLL > 1, Npack must be a multiple of UNROLL*WARP_SIZE. The loop checks only
//    `offset < Npack`, so accesses with u > 0 could otherwise run past Npack.
//    ReduceOrCopyMulti rounds stage 2a down accordingly and uses UNROLL = 1 for stage 2b.
template<class FUNC, typename T, int UNROLL, int MINSRCS, int MAXSRCS, int MINDSTS, int MAXDSTS>
__device__ __forceinline__ void ReduceCopy128bMulti( const int w, const int nw, const int t,
int nsrcs, const T* s[MAXSRCS], int ndsts, T* d[MAXDSTS],
const int elemOffset, const int Npack) {
const int inc = nw * UNROLL * WARP_SIZE;
int offset = w * UNROLL * WARP_SIZE + t;
// NOTE(review): pointers are formed for all MAXSRCS/MAXDSTS slots, but only the first
// nsrcs/ndsts of them are dereferenced below.
const Pack128* srcs[MAXSRCS];
for (int i=0; i<MAXSRCS; i++) srcs[i] = ((const Pack128*)(s[i]+elemOffset))+offset;
Pack128* dsts[MAXDSTS];
for (int i=0; i<MAXDSTS; i++) dsts[i] = ((Pack128*)(d[i]+elemOffset))+offset;
while (offset < Npack) {
Pack128 vals[UNROLL];
// Load and reduce
for (int u = 0; u < UNROLL; ++u) Fetch128(vals[u], srcs[0]+u*WARP_SIZE);
// Sources known at compile time (MINSRCS) ...
for (int i=1; i<MINSRCS; i++) {
Pack128 vals2[UNROLL];
for (int u = 0; u < UNROLL; ++u) Fetch128(vals2[u], srcs[i]+u*WARP_SIZE);
for (int u = 0; u < UNROLL; ++u) MULTI128<FUNC, T>()(vals[u], vals2[u]);
}
// ... then the optional sources, up to nsrcs.
#pragma unroll 1
for (int i=MINSRCS; i<MAXSRCS && i<nsrcs; i++) {
Pack128 vals2[UNROLL];
for (int u = 0; u < UNROLL; ++u) Fetch128(vals2[u], srcs[i]+u*WARP_SIZE);
for (int u = 0; u < UNROLL; ++u) MULTI128<FUNC, T>()(vals[u], vals2[u]);
}
// Store
for (int i = 0; i < MINDSTS; i++) {
for (int u = 0; u < UNROLL; ++u) Store128(dsts[i]+u*WARP_SIZE, vals[u]);
}
#pragma unroll 1
for (int i=MINDSTS; i<MAXDSTS && i<ndsts; i++) {
for (int u = 0; u < UNROLL; ++u) Store128(dsts[i]+u*WARP_SIZE, vals[u]);
}
for (int i=0; i<MAXSRCS; i++) srcs[i] += inc;
for (int i=0; i<MAXDSTS; i++) dsts[i] += inc;
offset += inc;
}
}
This completes the send/recv within the same card. Moving on to ncclSendRecvKernel, we can see that of the 320 threads mentioned earlier, 160 are used for send and 160 for recv. Both send and recv threads instantiate a ncclPrimitives, using directSend to send data and directRecv to receive data.
// Point-to-point part of the send/recv kernel (the delta <= 0 handling is elided).
// The worker threads are split in half:
//  - tid < nthreadsSplit + WARP_SIZE run the send side: nthreadsSplit workers plus one
//    sync warp.
//  - The remaining threads run the receive side, with their tid rebased to 0.
// Each side sends to, or receives from, a single peer, moving at most stepSize elements
// per GenericOp call.
template<int UNROLL, class FUNC, typename T>
__device__ void ncclSendRecvKernel(struct CollectiveArgs* args) {
const int tid = threadIdx.x;
const int nthreads = args->p2p.nThreads-2*WARP_SIZE;
// Compute pointers
const T* sendbuff = (const T*)args->sendbuff;
T* recvbuff = (T*)args->recvbuff;
...
struct ncclDevComm* comm = args->comm;
struct ncclChannel* channel = comm->channels+blockIdx.x;
// Chunk size in elements: the SIMPLE-protocol buffer split into NCCL_STEPS slots,
// then divided by SENDRECV_SLICEFACTOR. The primitives receive stepSize*4 as their slot size.
const int stepSize = comm->buffSizes[NCCL_PROTO_SIMPLE]/(sizeof(T)*NCCL_STEPS)/SENDRECV_SLICEFACTOR;
int nthreadsSplit = nthreads/2;
// We set NRECV or NSEND to 2 to use different barriers in primitives for the send threads and
// receive threads, but then we define all peers to -1 since sender threads don't receive and
// receive threads don't send.
int peerNone[2] = {-1,-1};
if (tid < nthreadsSplit + WARP_SIZE ) {
const ssize_t sendSize = args->p2p.sendCount;
if (sendSize < 0) return; // negative count: nothing to send on this channel
// Send to the rank `delta` positions ahead.
int peer = (comm->rank+(int)args->p2p.delta)%comm->nRanks;
ncclPrimitives<UNROLL, 1, 1, T, 2, 1, 1, FUNC>
prims(tid, nthreadsSplit, peerNone, &peer, recvbuff, stepSize*4, channel, comm);
if (sendSize == 0) {
// Zero-element send; presumably so the peer still observes this (empty) send. TODO confirm
prims.send(sendbuff, 0);
} else for (ssize_t offset = 0; offset < sendSize; offset += stepSize) {
int realChunkSize = min(stepSize, sendSize-offset);
// NOTE(review): this assumes ALIGN_SIZE rounds up. Under that assumption nelem below
// equals min(stepSize, remaining) only if stepSize is a multiple of
// nthreads*sizeof(uint64_t)/sizeof(T). Note that nthreads is the full worker count,
// not nthreadsSplit.
ALIGN_SIZE(realChunkSize, nthreads*sizeof(uint64_t)/sizeof(T));
int nelem = min(realChunkSize, sendSize-offset);
prims.directSend(sendbuff+offset, offset, nelem);
}
} else {
const ssize_t recvSize = args->p2p.recvCount;
if (recvSize < 0) return; // negative count: nothing to receive on this channel
// Receive from the rank `delta` positions behind (+nRanks keeps the modulo non-negative).
int peer = (comm->rank-(int)args->p2p.delta+comm->nRanks)%comm->nRanks;
ncclPrimitives<UNROLL, 1, 1, T, 1, 2, 1, FUNC>
prims(tid-nthreadsSplit-WARP_SIZE, nthreads-nthreadsSplit, &peer, peerNone, recvbuff, stepSize*4, channel, comm);
if (recvSize == 0) {
prims.recv(recvbuff, 0);
} else for (ssize_t offset = 0; offset < recvSize; offset += stepSize) {
int realChunkSize = min(stepSize, recvSize-offset);
ALIGN_SIZE(realChunkSize, nthreads*sizeof(uint64_t)/sizeof(T));
int nelem = min(realChunkSize, recvSize-offset);
prims.directRecv(recvbuff+offset, offset, nelem);
}
}
}
For reference, here are the template arguments of the send-side and receive-side instantiations:
/*
send:
UNROLL: 4,
SLICESPERCHUNK: 1,
SLICESTEPS: 1,
T: int8_t,
NRECV: 2,
NSEND: 1,
DIRECT: 1,
FUNC: FuncSum<int8_t>
recv:
UNROLL: 4,
SLICESPERCHUNK: 1,
SLICESTEPS: 1,
T: int8_t,
NRECV: 1,
NSEND: 2,
DIRECT: 1,
FUNC: FuncSum<int8_t>
*/
// Device-side send/recv primitives; members are elided here and shown individually below.
// NRECV/NSEND size the per-peer arrays. ncclSendRecvKernel instantiates the send side
// with NRECV=2, NSEND=1 and the receive side with NRECV=1, NSEND=2. barrier() and
// subBarrier() test NSEND > NRECV, so this makes the two sides use different named barriers.
template <int UNROLL, int SLICESPERCHUNK, int SLICESTEPS, typename T, int NRECV, int NSEND, int DIRECT, class FUNC>
class ncclPrimitives {
...
}
First look at ncclPrimitives’ constructor. Here nthreads is 160 - 32 = 128, with 32 threads for synchronization. Since send’s recvPeer is -1, send won’t loadRecvConn, and recv won’t loadSendConn.
// Set up the primitives for one thread group (the send half or the receive half).
// - tid is the thread index within the group and nthreads the number of worker threads.
// - Threads with tid >= nthreads form the sync warp; barrier() counts nthreads + WARP_SIZE.
// - Peer lists are scanned in order, and each loop stops at the first negative entry.
// - directBuff is offered to a receive connection that supports direct GPU access.
__device__ __forceinline__
ncclPrimitives(const int tid, const int nthreads, int* recvPeers, int* sendPeers, T* directBuff, int stepSize, struct ncclChannel* channel, struct ncclDevComm* comm)
: comm(comm), tid(tid), nthreads(nthreads), wid(tid%WARP_SIZE), stepSize(stepSize) {
// NOTE: despite its name, wid is the lane index within the warp (tid % WARP_SIZE).
// Make sure step is updated before we read it.
barrier();
for (int i=0; i<NRECV && recvPeers[i] >= 0; i++) loadRecvConn(&channel->devPeers[recvPeers[i]].recv.conn, i, directBuff);
for (int i=0; i<NSEND && sendPeers[i] >= 0; i++) loadSendConn(&channel->devPeers[sendPeers[i]].send.conn, i);
loadRecvSync();
loadSendSync();
}
Next, the receive-side ncclConnInfo is loaded: recvBuff and the step are saved, with the step rounded up to a multiple of SLICESPERCHUNK*SLICESTEPS. Because the p2p setup uses P2P read, NCCL_DIRECT_GPU is not set in conn->direct, so the first if is not entered. Lane i of every warp (lane 0 here, since there is only one peer) saves the ncclConnInfo pointer and initializes recvConnTail and recvConnHead to recvStep.
// Record receive connection i:
// - Save its SIMPLE-protocol buffer and its current step, rounded up to a multiple of
//   SLICESPERCHUNK*SLICESTEPS.
// - In direct GPU mode, thread 0 publishes directBuff through conn->ptrExchange; loadSendConn
//   on the sending side spins on that pointer.
// - Lane i of every warp keeps conn and starts recvConnTail/recvConnHead at the rounded step.
__device__ __forceinline__ void loadRecvConn(struct ncclConnInfo* conn, int i, T* directBuff) {
recvBuff[i] = (const T*)conn->buffs[NCCL_PROTO_SIMPLE];
recvStep[i] = conn->step;
recvStep[i] = ROUNDUP(recvStep[i], SLICESPERCHUNK*SLICESTEPS);
recvDirectBuff[i] = NULL;
if (DIRECT && (conn->direct & NCCL_DIRECT_GPU)) {
recvDirectBuff[i] = directBuff;
if (tid == 0) *conn->ptrExchange = directBuff;
}
if (wid == i) recvConn = conn;
if (wid == i) recvConnTail = recvConnHead = recvStep[i]; // Make sure we set this after rounding up
nrecv++;
}
Then load send’s conn, save step and sendBuff. The first thread of each warp saves conn and initializes sendConnTail and sendConnHead to step.
// Record send connection i:
// - Save its SIMPLE-protocol buffer and its current step, rounded up to a multiple of
//   SLICESPERCHUNK*SLICESTEPS.
// - In direct GPU mode, spin until the receiver publishes its buffer through
//   conn->ptrExchange, then reset the exchange slot to NULL.
// - Lane i of every warp keeps conn and starts sendConnTail/sendConnHead at the rounded step.
__device__ __forceinline__ void loadSendConn(struct ncclConnInfo* conn, int i) {
sendBuff[i] = (T*)conn->buffs[NCCL_PROTO_SIMPLE];
sendStep[i] = conn->step;
sendStep[i] = ROUNDUP(sendStep[i], SLICESPERCHUNK*SLICESTEPS);
sendDirectBuff[i] = NULL;
if (DIRECT && (conn->direct & NCCL_DIRECT_GPU)) {
void* volatile* ptr = conn->ptrExchange;
// Busy-wait (volatile read) until the receiver's direct buffer appears.
while ((sendDirectBuff[i] = (T*)(*ptr)) == NULL);
// All threads must have read the pointer before thread 0 clears it.
barrier();
if (tid == 0) *ptr = NULL;
}
if (wid == i) sendConn = conn;
if (wid == i) sendConnTail = sendConnHead = sendStep[i]; // Make sure we set this after rounding up
nsend++;
}
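The first thread of the second warp saves the tail pointer and caches the tail’s current value. The first thread of the sync warp saves the head pointer and writes recvConnHead to it, returning any credits skipped by rounding up.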
// Hand out the receive-side synchronization pointers:
// - Lanes wid < nrecv of the second warp (tid in [WARP_SIZE, 2*WARP_SIZE)) get the tail
//   pointer and cache its current value; waitRecv polls it.
// - Lanes wid < nrecv of the sync warp (tid >= nthreads) get the head pointer and
//   immediately store recvConnHead to it.
__device__ __forceinline__ void loadRecvSync() {
if (tid >= WARP_SIZE && tid < 2*WARP_SIZE && wid<nrecv) {
recvConnTailPtr = recvConn->tail;
recvConnTailCache = *recvConnTailPtr;
}
if (tid >= nthreads && wid < nrecv) {
recvConnHeadPtr = recvConn->head;
// Return credits in case we rounded up.
*recvConnHeadPtr = recvConnHead;
}
}
Thread 0 saves the head pointer and caches the head’s current value. It also saves the fifo pointer, which is used by the proxy and not needed for now. The first thread of the sync warp saves the tail pointer.
// Hand out the send-side synchronization pointers:
// - Threads tid < nsend (thread 0 for a single peer) get the head pointer, cache its value
//   and also keep the connection's fifo pointer; waitSend polls the head.
// - Lanes wid < nsend of the sync warp (tid >= nthreads) get the tail pointer.
__device__ __forceinline__ void loadSendSync() {
if (tid < nsend) {
sendConnHeadPtr = sendConn->head;
sendConnHeadCache = *sendConnHeadPtr;
sendConnFifoPtr = sendConn->fifo;
}
if (tid >= nthreads && wid<nsend) {
sendConnTailPtr = sendConn->tail;
}
}
Now let’s look at what these variables do. In p2p transport setup stage (discussed in section 8), each rank created variables to coordinate send/receive process, as shown below. Since p2p read is supported, buff is at the sender; tail is at the receiver and shared by sender and receiver, updated by sender; head is at sender and shared by sender and receiver, updated by receiver. In ncclPrimitives’ receiver, tail is called recvConnTailPtr and head is called recvConnHeadPtr; while in sender, tail is called sendConnTailPtr and head is called sendConnHeadPtr.
Figure 4
Next, let’s see how these variables coordinate the send/receive process.
Figure 5
The yellow boxes in the middle are the buff shown in Figure 4. The entire buff is divided into NCCL_STEPS slots; Figure 5 shows only six of them.
sendConnHead, the value pointed to by sendConnTailPtr, and sendStep are updated by the sender, and each advances by one per send. They therefore always hold the same value, so these variables seem somewhat redundant.
recvConnTail, the value pointed to by recvConnHeadPtr, and recvStep are updated by the receiver, and each advances by one per receive, so they also always hold the same value.
Therefore, the receiver has data available as long as recvConnTail is less than the tail value pointed to by recvConnTailPtr (cached in recvConnTailCache). After waiting, it increments recvConnTail by one to record that another slot has been received.
// Wait until the next SLICESTEPS steps of data are available, then claim them.
// Only a thread with a non-null recvConnTailPtr (set in loadRecvSync) waits. It re-reads the
// sender-updated tail into recvConnTailCache until tail >= recvConnTail + SLICESTEPS, or
// until checkAbort reports an abort. recvConnTail is advanced in either case.
inline __device__ void waitRecv() {
spins = 0;
if (recvConnTailPtr) {
while (recvConnTailCache < recvConnTail + SLICESTEPS) {
recvConnTailCache = *recvConnTailPtr;
if (checkAbort(wid, 0)) break;
}
recvConnTail += SLICESTEPS;
}
}
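For the sender, space is available as long as sendConnHead is less than the head value pointed to by sendConnHeadPtr (cached in sendConnHeadCache) plus NCCL_STEPS, i.e., while fewer than NCCL_STEPS slots are still waiting for the receiver to consume them. After waiting, sendConnHead is incremented by one to record that another send has been issued.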
// Wait until the send FIFO has room for SLICESTEPS more steps, then claim them.
// - Only a thread with a non-null sendConnHeadPtr (set in loadSendSync) waits.
// - It re-reads the receiver-updated head into sendConnHeadCache until
//   sendConnHead + SLICESTEPS <= head + NCCL_STEPS, or until checkAbort reports an abort.
// - If a fifo is attached, nbytes is recorded in the entry of the slot being claimed.
inline __device__ void waitSend(int nbytes) {
spins = 0;
if (sendConnHeadPtr) {
while (sendConnHeadCache + NCCL_STEPS < sendConnHead + SLICESTEPS) {
sendConnHeadCache = *sendConnHeadPtr;
if (checkAbort(wid, 1)) break;
}
if (sendConnFifoPtr) {
sendConnFifoPtr[sendConnHead%NCCL_STEPS] = nbytes;
}
sendConnHead += SLICESTEPS;
}
}
Then look at the directSend process. The srcs array has only one element, srcPtr is args->sendbuff (user-provided), so srcs[0] is sendbuff; dsts array also has only one element.
// Send nelem elements from the user buffer src to the send peer(s).
// GenericOp flags: DIRECTSEND=1, SEND=1, SRC=1.
// directOffset is the element offset used when the peer has published a direct buffer.
__device__ __forceinline__ void
directSend(const T* src, ssize_t directOffset, int nelem) {
GenericOp<0, 1, 0, 1, 1, 0>(src, NULL, nelem, directOffset);
}
// Receive nelem elements from the receive peer(s) into the user buffer dst.
// GenericOp flags: DIRECTRECV=1, RECV=1, DST=1.
__device__ __forceinline__ void
directRecv(T* dst, ssize_t directOffset, int nelem) {
GenericOp<1, 0, 1, 0, 0, 1>(NULL, dst, nelem, directOffset);
}
/*
send:
DIRECTRECV: 0
DIRECTSEND: 1
RECV: 0
SEND: 1
SRC: 1
DST: 0
dstPtr: NULL
*/
// Setup part of GenericOp: computes the slice geometry and the source/destination
// pointer arrays used by the slice loop (elided here).
template <int DIRECTRECV, int DIRECTSEND, int RECV, int SEND, int SRC, int DST>
inline __device__ void
GenericOp(const T* srcPtr, T* dstPtr, int nelem, ssize_t directOffset) {
int offset = 0;
int sliceSize = stepSize*SLICESTEPS;
// Elements per slice: nelem split over SLICESPERCHUNK slices and rounded up to a multiple
// of 16, but never fewer than sliceSize/32.
int dataSize = max(DIVUP(nelem, 16*SLICESPERCHUNK)*16, sliceSize/32);
// Sources: slot 0 is the user buffer (SRC) or the peer-0 receive pointer (from
// directRecvPtr, defined elsewhere); the other receive peers follow.
const T* srcs[RECV*NRECV+SRC];
srcs[0] = SRC ? srcPtr : directRecvPtr<DIRECTRECV>(0, directOffset);
if (RECV) {
if (SRC) srcs[1] = recvPtr(0);
for (int i=1; i<NRECV && i<nrecv; i++) srcs[SRC+i] = recvPtr(i);
}
// Destinations: slot 0 is the user buffer (DST) or the peer-0 send pointer (direct buffer
// or current FIFO slot); the other send peers follow.
T* dsts[SEND*NSEND+DST];
dsts[0] = DST ? dstPtr : directSendPtr<DIRECTSEND>(0, directOffset);
if (SEND) {
if (DST) dsts[1] = directSendPtr<DIRECTSEND>(0, directOffset);
for (int i=1; i<NSEND && i<nsend; i++) dsts[DST+i] = directSendPtr<DIRECTSEND>(i, directOffset);
}
...
}
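DIRECTSEND is 1, but sendDirectBuff is NULL: no direct pointer was exchanged, because NCCL_DIRECT_GPU is not set. directSendPtr therefore falls back to sendPtr(i), and the destination points into the p2p buff.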
// Destination pointer for send peer i. When direct mode is compiled in and the receiver
// published its buffer (sendDirectBuff[i], obtained through ptrExchange), write straight
// into it at directOffset. Otherwise use the current slot of the send FIFO.
template <int DIRECTSEND>
inline __device__ T* directSendPtr(int i, ssize_t directOffset) {
  if (DIRECTSEND && sendDirectBuff[i]) {
    return sendDirectBuff[i] + directOffset;
  }
  return sendPtr(i);
}
We can see sendPtr finds the next block to use in buff based on sendStep, i.e., one of the yellow boxes in Figure 5.
// Element offset of the FIFO slot for sendStep[i]. The buffer holds NCCL_STEPS slots of
// stepSize elements each.
inline __device__ int sendOffset(int i) {
  const auto slot = sendStep[i] % NCCL_STEPS;
  return slot * stepSize;
}
// Pointer to the current FIFO slot in send peer i's buffer.
inline __device__ T* sendPtr(int i) {
  T* const base = (T*)sendBuff[i];
  return base + sendOffset(i);
}
Before looking at actual data sending, let’s look at several sync functions. barrier() synchronizes all send or receive threads, subBarrier() synchronizes data transfer threads (excluding sync threads) within send/receive threads, essentially synchronizing different thread groups through different barriers.
// Named-barrier helpers. The send side (NSEND > NRECV) uses barriers 1 and 3 and the
// receive side uses barriers 2 and 4, so the two halves of the block never wait on each
// other. Barrier 0 is left to __syncthreads().
// barrier(): worker threads plus the sync warp (nthreads + WARP_SIZE threads).
inline __device__ void barrier() {
if (NSEND>NRECV) {
asm volatile ("bar.sync 1, %0;" :: "r"(nthreads+WARP_SIZE));
} else {
asm volatile ("bar.sync 2, %0;" :: "r"(nthreads+WARP_SIZE));
}
}
// subBarrier(): worker threads only (nthreads threads); the sync warp does not take part.
inline __device__ void subBarrier() {
if (NSEND>NRECV) {
asm volatile ("bar.sync 3, %0;" :: "r"(nthreads));
} else {
asm volatile ("bar.sync 4, %0;" :: "r"(nthreads));
}
}
Continuing with the send path, each worker (non-sync) thread calls waitSend, described above, which blocks until there is room to send. Only the first thread actually waits, because it is the only one holding a non-null sendConnHeadPtr. The other threads therefore wait for it at subBarrier(); otherwise they could overwrite a slot of buff that the receiver has not consumed yet. ReduceOrCopyMulti, covered earlier, then copies the data from src to dst. The barrier() that follows ensures the queue pointers are updated only after the data copy completes, so a pointer update can never be observed before the data it refers to has been written. The step is then advanced with incSend. The sync threads execute __threadfence_system before postSend updates the tail pointer. This fence guarantees that any thread that sees the updated tail pointer also sees the correct data in buff. A system-level fence is required because inter-machine communication involves CPU threads, which need a system-wide memory barrier for correctness. Since __threadfence_system is expensive, a separate warp is dedicated to synchronization to improve performance. The fence (executed by lane 0) and postSend may run on different threads of that warp, so __syncwarp is needed to synchronize the warp between them.
// Slice loop of GenericOp. For each of the SLICESPERCHUNK slices:
//  1. Workers wait for a free send slot and/or available receive data, then copy.
//  2. All threads (workers and sync warp) meet at barrier() so the copy is complete before
//     the steps advance.
//  3. The sync warp fences and publishes the new tail (send) and/or head (receive).
template <int DIRECTRECV, int DIRECTSEND, int RECV, int SEND, int SRC, int DST>
inline __device__ void
GenericOp(const T* srcPtr, T* dstPtr, int nelem, ssize_t directOffset) {
...
bool syncThread = tid >= nthreads; // the extra warp beyond the nthreads workers
#pragma unroll
for (int slice=0; slice<SLICESPERCHUNK; ++slice) {
// Elements in this slice. This is 0 once nelem is exhausted; the step is still advanced
// and posted below.
int realSize = max(0, min(dataSize, nelem-offset));
if (!syncThread) {
// Only the threads holding the head/tail pointers actually spin in these calls.
if (SEND) waitSend(realSize*sizeof(T));
if (RECV) waitRecv();
if (realSize > 0) {
// The other workers wait here for the waiting thread(s) before touching the buffer.
subBarrier();
if (DIRECTRECV && recvDirectBuff[0]) {
// We can only have one direct receive. Since srcs[0] == dstPtr+offset, skip one copy
if (SEND) {
ReduceOrCopyMulti<UNROLL, FUNC, T, 1, 1, 1, NSEND>(tid, nthreads, 1, srcs, nsend, dsts+1, realSize);
}
} else {
ReduceOrCopyMulti<UNROLL, FUNC, T, RECV+SRC, RECV*NRECV+SRC, SEND+DST, SEND*NSEND+DST>(tid, nthreads, RECV*nrecv+SRC, srcs, SEND*nsend+DST, dsts, realSize);
}
}
}
// The copy of this slice is complete before any step or pointer is advanced.
barrier();
// Presumably apply incSend/incRecv to each active peer (macros defined elsewhere).
FOR_SEND(incSend);
FOR_RECV(incRecv);
if (syncThread) {
if (SEND) {
// Make the slice's data visible system-wide before the tail is published.
if (realSize > 0 && wid == 0) __threadfence_system();
__syncwarp();
postSend();
}
if (RECV) postRecv();
}
// Move to the next slice. User buffers advance by realSize; peer-buffer pointers advance
// via directRecvInc/directSendInc or by sliceSize.
srcs[0] += SRC ? realSize : directRecvInc<DIRECTRECV>(0, realSize, sliceSize);
for (int i=1-SRC; i<RECV*NRECV; i++) srcs[SRC+i] += sliceSize;
dsts[0] += DST ? realSize : directSendInc<DIRECTSEND>(0, realSize, sliceSize);
for (int i=1-DST; i<SEND*NSEND; i++) dsts[DST+i] += directSendInc<DIRECTSEND>(i, realSize, sliceSize);
offset += realSize;
}
}
This basically completes the ncclSend/ncclRecv process within a single machine, mainly in two steps: first record user operations through peerlist and generate kernel parameters based on records, then launch kernel to execute copying. For different cards, send copies data from user-specified sendbuff to nccl p2p transport’s buff, recv copies data from buff to user-specified recvbuff. Buff here acts as a fifo, with nccl coordinating send and receive processes through head and tail pointers. For same card, kernel directly copies data from sendbuff to recvbuff.