bootstrapInit()利用已知的根节(rank0)网络地址,建立一个环形网络,allgather获取所有rank的信息。
视频教程:1.4 NCCL源码解读bootstrap网络连接建立bootstrapInit()引导网络_哔哩哔哩_bilibili
核心逻辑:
1、函数输入ncclUniqueId,从而获得ncclUniqueId中包含的rank0的网络地址,每个rank上都有rank0的网络地址;
2、所有rank根据rank0的网络地址,建立socket并向rank0发送自己的网络地址,rank0上现在就有所有rank的网络地址了;
3、rank0告诉每个rank它的下一个节点网络地址,完成环形网络建立;
4、AllGather
全局收集所有节点的网络地址;
注:ncclUniqueId就是前面课程所说的,在rank0上产生,并MPI广播给所有rank,UniqueId由两部分组成,前半部分是随机数,后半部分是rank0的网络地址,因此所有rank都知道rank0的网络地址,都可以和rank0通信数据。具体细节可参考前面的课程:
NCCL源码详解1:NCCL官网使用/调用案例 Example : One Device per Process or Thread包含视频教程-CSDN博客
NCCL源码详解2:通信初始化如何获取唯一ID UniqueId,ncclGetUniqueId()中ncclInit()、bootstrapGetUniqueId()包含视频教程-CSDN博客
图示:
爱串门的小马驹太牛皮了,居然有图示,我都爱死我自己了。
四张图分别对应上面核心逻辑的四步:
1、函数输入ncclUniqueId(包含的rank0的网络地址),每个rank上都有rank0的网络地址,当然也有自己的网络地址。
2、所有rank根据rank0的网络地址,建立socket并向rank0发送自己的网络地址,rank0上现在就有所有rank的网络地址了;
3、rank0告诉每个rank它的下一个节点网络地址,
然后完成环形网络建立;
4、AllGather
全局收集所有节点的网络地址;
疑问解答:
有的小可爱可能会问:不建立环形网络不也行么,全部通过rank0来传递信息,也能获取所有rank的信息啊!
哈哈哈,小可爱很聪明啊,确实可以。只是rank0的通信负担太重,会存在性能瓶颈。在bootstrapInit()的源码中也可以看出这一点,当节点数大于128时,延迟到根节点的连接。
源码速递:
源码位置:nccl-master\src\bootstrap.cc
///
//1、函数的输入handle就是UniqueID,被强制转化欸ncclBootstrapHandle,包含rank0的网络地址
ncclResult_t bootstrapInit(struct ncclBootstrapHandle* handle, struct ncclComm* comm) {
// 获取当前节点的排名
int rank = comm->rank;
// 获取参与节点的数量
int nranks = comm->nRanks;
// 分配内存并初始化bootstrapState结构体,用于管理启动阶段的状态
struct bootstrapState* state;
NCCLCHECK(ncclCalloc(&state, 1));
state->rank = rank; // 设置当前节点的排名
state->nranks = nranks; // 设置参与节点的数量
state->abortFlag = comm->abortFlag; // 设置是否应中止通信的标志
// 将bootstrapState指针赋予comm结构体
comm->bootstrap = state;
// 设置魔术数字,用于校验
comm->magic = state->magic = handle->magic;
// 记录日志,显示当前节点的排名和参与节点的数量
TRACE(NCCL_INIT, "rank %d nranks %d", rank, nranks);
// 为当前节点准备发送给其他节点的信息
struct extInfo info = { 0 };
info.rank = rank; // 设置当前节点的排名
info.nranks = nranks; // 设置参与节点的数量
// 创建一个监听套接字,允许其他节点联系当前节点
NCCLCHECK(ncclSocketInit(&state->listenSock, &bootstrapNetIfAddr, comm->magic, ncclSocketTypeBootstrap, comm->abortFlag)); // 初始化监听套接字
NCCLCHECK(ncclSocketListen(&state->listenSock)); // 设置监听状态
NCCLCHECK(ncclSocketGetAddr(&state->listenSock, &info.extAddressListen)); // 获取监听套接字的地址
// 创建另一个监听套接字,允许根节点联系当前节点
NCCLCHECK(ncclSocketInit(&listenSockRoot, &bootstrapNetIfAddr, comm->magic, ncclSocketTypeBootstrap, comm->abortFlag)); // 初始化监听套接字
NCCLCHECK(ncclSocketListen(&listenSockRoot)); // 设置监听状态
NCCLCHECK(ncclSocketGetAddr(&listenSockRoot, &info.extAddressListenRoot)); // 获取监听套接字的地址
// 如果参与节点的数量大于128,则延迟连接到根节点,以减轻根节点的负载
if (nranks > 128) {
long msec = rank; // 计算延迟时间
struct timespec tv; // 定义时间戳结构体
tv.tv_sec = msec / 1000; // 秒部分
tv.tv_nsec = 1000000 * (msec % 1000); // 毫秒部分
TRACE(NCCL_INIT, "rank %d delaying connection to root by %ld msec", rank, msec); // 记录日志,显示延迟时间
(void) nanosleep(&tv, NULL); // 延迟指定时间
}
//
2、所有根据rank0的网络地址,建立socket并向rank0发送自己的网络地址;
// 发送当前节点的信息给根节点
NCCLCHECK(ncclSocketInit(&sock, &handle->addr, comm->magic, ncclSocketTypeBootstrap, comm->abortFlag)); // 初始化套接字
NCCLCHECK(ncclSocketConnect(&sock)); // 连接到根节点
NCCLCHECK(bootstrapNetSend(&sock, &info, sizeof(info))); // 发送信息
NCCLCHECK(ncclSocketClose(&sock)); // 关闭套接字
///
//3、rank0告诉每个rank它的下一个节点网络地址,完成环形网络建立;
// 从根节点接收下一个节点在启动环中的信息
NCCLCHECK(ncclSocketInit(&sock)); // 初始化套接字
NCCLCHECK(ncclSocketAccept(&sock, &listenSockRoot)); // 接受来自根节点的连接请求
NCCLCHECK(bootstrapNetRecv(&sock, &nextAddr, sizeof(union ncclSocketAddress))); // 接收信息
NCCLCHECK(ncclSocketClose(&sock)); // 关闭套接字
NCCLCHECK(ncclSocketClose(&listenSockRoot)); // 关闭根节点的监听套接字
// 初始化与下一个节点的发送套接字
NCCLCHECK(ncclSocketInit(&state->ringSendSocket, &nextAddr, comm->magic, ncclSocketTypeBootstrap, comm->abortFlag)); // 初始化套接字
NCCLCHECK(ncclSocketConnect(&state->ringSendSocket)); // 连接到下一个节点
// 接受来自前一个节点的环连接请求
NCCLCHECK(ncclSocketInit(&state->ringRecvSocket)); // 初始化套接字
NCCLCHECK(ncclSocketAccept(&state->ringRecvSocket, &state->listenSock)); // 接受连接请求
///
4、AllGather全局收集所有节点的网络地址;
// 全局收集所有节点的监听器地址
NCCLCHECK(ncclCalloc(&state->peerCommAddresses, nranks)); // 分配内存
NCCLCHECK(ncclSocketGetAddr(&state->listenSock, state->peerCommAddresses+rank)); // 获取当前节点的监听器地址
NCCLCHECK(bootstrapAllGather(state, state->peerCommAddresses, sizeof(union ncclSocketAddress))); // 全局收集监听器地址
// 创建服务代理套接字
NCCLCHECK(ncclCalloc(&state->peerProxyAddresses, nranks)); // 分配内存
NCCLCHECK(ncclCalloc(&state->peerProxyAddressesUDS, nranks)); // 分配内存
// 初始化服务代理
NCCLCHECK(ncclCalloc(&proxySocket, 1)); // 分配内存
NCCLCHECK(ncclSocketInit(proxySocket, &bootstrapNetIfAddr, comm->magic, ncclSocketTypeProxy, comm->abortFlag)); // 初始化套接字
NCCLCHECK(ncclSocketListen(proxySocket)); // 设置监听状态
NCCLCHECK(ncclSocketGetAddr(proxySocket, state->peerProxyAddresses+rank)); // 获取当前节点的代理地址
NCCLCHECK(bootstrapAllGather(state, state->peerProxyAddresses, sizeof(union ncclSocketAddress))); // 全局收集代理地址
uint64_t randId; // 随机ID
NCCLCHECK(getRandomData(&randId, sizeof(randId))); // 生成随机数据
state->peerProxyAddressesUDS[rank] = getPidHash()+randId; // 生成唯一的UDS名称
NCCLCHECK(bootstrapAllGather(state, state->peerProxyAddressesUDS, sizeof(*state->peerProxyAddressesUDS))); // 全局收集UDS名称
NCCLCHECK(ncclProxyInit(comm, proxySocket, state->peerProxyAddresses, state->peerProxyAddressesUDS)); // 初始化代理
// 记录完成初始化的消息
TRACE(NCCL_INIT, "rank %d nranks %d - DONE", rank, nranks);
// 返回成功状态
return ncclSuccess;
}