【源码阅读】分布式通信部分代码阅读

#用于分布式环境下进行all-reduce的操作
def _reduce(tensor: Tensor) -> Tensor:
    if gpc.get_world_size(ParallelMode.TENSOR) == 1:
        return tensor

    dist.all_reduce(tensor,
                    op=dist.ReduceOp.SUM,
                    group=gpc.get_group(ParallelMode.TENSOR),#指定用于通信的进程组,通过gpc.get_group(ParallelMode.TENSOR)获取张量模式下的并行模式组
                    async_op=False)#同步执行全局求和操作,即在该操作完成之前会阻塞程序继续执行,直到所有进程完成全局求和操作

    return tensor
  1. 对张量进行分割
#用于在分布式环境下将张量沿指定维度分割成多份
def _split(tensor: Tensor, dim: int = -1) -> Tensor:
    if gpc.get_world_size(ParallelMode.TENSOR) == 1:
        return tensor
    #根据张量在指定维度上的大小以及当前并行环境中的进程数,计算出每个进程需要处理的分割大小,以便进行正确的张量分割操作
    split_size = divide(tensor.shape[dim], gpc.get_world_size(ParallelMode.TENSOR))#gpc.get_world_size(ParallelMode.TENSOR)获取张量模式下的并行模式组的总进程数,即世界大小
    tensor_list = torch.split(tensor, split_size, dim=dim)
    
    #实现了根据当前进程在张量模式下的本地排名,选择对应的分割后的小张量,并确保所选的小张量是连续存储的,以便进行后续计算操作
    output = tensor_list[gpc.get_local_rank(ParallelMode.TENSOR)].contiguous()

    return output
  • tensor_list:表示存储了分割后的小张量的张量列表。
  • gpc.get_local_rank(ParallelMode.TENSOR):获取当前进程在张量模式下的本地排名,即该进程在当前并行模式组中的排名。
  • tensor_list[gpc.get_local_rank(ParallelMode.TENSOR)]:根据本地排名选择对应的分割后的小张量。
  • .contiguous():调用contiguous()方法,确保所选的小张量是连续的,即内存中的数据是连续存储的,方便后续计算操作。
  1. 在分布式环境下收集各个进程的张量数据并进行合并
#用于在分布式环境下收集各个进程的张量数据并进行合并
def _gather(tensor: Tensor, dim: int = -1) -> Tensor:
    if gpc.get_world_size(ParallelMode.TENSOR) == 1:#检查是否为1,为1表示为单设备环境,没有多个进程参与计算
        return tensor

    if dim == 1 and list(tensor.shape)[0] == 1:
        output_shape = list(tensor.shape)
        output_shape[1] *= gpc.get_world_size(ParallelMode.TENSOR)
        output = torch.empty(output_shape, dtype=tensor.dtype, device=tensor.device)
        tensor_list = output.chunk(gpc.get_world_size(ParallelMode.TENSOR), dim=1)
        dist.all_gather(list(tensor_list),
                        tensor,
                        group=gpc.get_group(ParallelMode.TENSOR),
                        async_op=False)
    else:
        tensor_list = [
            torch.empty_like(tensor) for _ in range(gpc.get_world_size(ParallelMode.TENSOR))
        ]
        dist.all_gather(tensor_list,
                        tensor,
                        group=gpc.get_group(ParallelMode.TENSOR),
                        async_op=False)
        output = torch.cat(tensor_list, dim=dim)

    return output
  • 3
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值