reduce_from_model_parallel_region 和gather_from_model_parallel_region的异同

reduce_from_model_parallel_regiongather_from_model_parallel_region 都是用于在多 GPU 模型并行环境中合并数据的函数,但它们的具体操作和用途有所不同。

reduce_from_model_parallel_region

这个函数的主要作用是将多个 GPU 上的张量进行加和操作。所有参与计算的 GPU 最终都会获得相同的加和结果。

用途
  • 主要用于需要对多个 GPU 上的部分结果进行累加操作的场景。
  • 每个 GPU 上的数据进行加和,最终每个 GPU 都会得到相同的合并结果。
例子

假设有两个 GPU,各自计算出部分结果 output_parallel_0output_parallel_1

  • 使用 reduce_from_model_parallel_region 后,两个 GPU 上的结果都会是 output_parallel_0 + output_parallel_1
gather_from_model_parallel_region

这个函数的主要作用是将多个 GPU 上的张量收集(各个张量在最后一维连接)到一个指定的 GPU 上。其他 GPU 不会得到完整的结果。

代码注释原文:Gather tensors and concatinate along the last dimension.

用途
  • 主要用于需要将多个 GPU 上的部分结果收集到一个 GPU 上进行后续处理的场景。
  • 一个指定的 GPU(通常是 rank 0)收集所有 GPU 上的数据,其他 GPU 不会得到完整的合并结果。
例子

假设有两个 GPU,各自计算出部分结果 output_parallel_0output_parallel_1

  • 使用 gather_from_model_parallel_region 后,一个指定的 GPU(如 rank 0)上会得到 output_parallel_0output_parallel_1 的合并结果,其他 GPU 上不会得到这个结果。

函数实现和使用场景的区别

reduce_from_model_parallel_region
def reduce_from_model_parallel_region(output_parallel: torch.Tensor) -> torch.Tensor:
    """Reduce across all the model parallel GPUs."""
    if get_model_parallel_world_size() == 1:
        return output_parallel

    # Bypass the function if we are using only 1 GPU.
    torch.distributed.all_reduce(output_parallel, group=get_model_parallel_group())

    return output_parallel

实现:使用 torch.distributed.all_reduce 对所有 GPU 上的张量进行加和操作。

  • 使用场景:需要所有 GPU 上都得到相同的加和结果。
gather_from_model_parallel_region
def gather_from_model_parallel_region(output_parallel: torch.Tensor) -> torch.Tensor:
    """Gather tensors and concatinate along the last dimension."""
    world_size = get_model_parallel_world_size()
    if world_size == 1:
        return output_parallel

    # Size and dimension.
    numel = output_parallel.numel()
    output_tensor = torch.empty(numel * world_size, dtype=output_parallel.dtype,
                                device=output_parallel.device)
    # Do the all-gather.
    torch.distributed._all_gather_base(output_tensor, output_parallel.contiguous(),
                                       group=get_model_parallel_group())

    # If we are using the original approach, then this could be something like:
    # tensor_list = [torch.empty_like(output_parallel) for _ in range(world_size)]
    # torch.distributed.all_gather(tensor_list, output_parallel, group=get_model_parallel_group())
    # output_tensor = torch.cat(tensor_list, dim=-1)

    return output_tensor
  • 实现:使用 torch.distributed._all_gather_basetorch.distributed.all_gather 将所有 GPU 上的张量收集到一个张量中。
  • 使用场景:需要将所有 GPU 上的张量收集到一个指定的 GPU 上进行后续处理。

总结

  • reduce_from_model_parallel_region:将多个 GPU 上的张量进行加和操作,所有 GPU 最终得到相同的结果。
  • gather_from_model_parallel_region:将多个 GPU 上的张量收集到一个指定的 GPU 上,其他 GPU 不会得到完整的结果。
  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值