reduce_from_model_parallel_region
和 gather_from_model_parallel_region
都是用于在多 GPU 模型并行环境中合并数据的函数,但它们的具体操作和用途有所不同。
reduce_from_model_parallel_region
这个函数的主要作用是将多个 GPU 上的张量进行加和操作。所有参与计算的 GPU 最终都会获得相同的加和结果。
用途
- 主要用于需要对多个 GPU 上的部分结果进行累加操作的场景。
- 每个 GPU 上的数据进行加和,最终每个 GPU 都会得到相同的合并结果。
例子
假设有两个 GPU,各自计算出部分结果 output_parallel_0
和 output_parallel_1
:
- 使用
reduce_from_model_parallel_region
后,两个 GPU 上的结果都会是output_parallel_0 + output_parallel_1
。
gather_from_model_parallel_region
这个函数的主要作用是将多个 GPU 上的张量收集(各个张量在最后一维连接)到一个指定的 GPU 上。其他 GPU 不会得到完整的结果。
代码注释原文:Gather tensors and concatinate along the last dimension.
用途
- 主要用于需要将多个 GPU 上的部分结果收集到一个 GPU 上进行后续处理的场景。
- 一个指定的 GPU(通常是 rank 0)收集所有 GPU 上的数据,其他 GPU 不会得到完整的合并结果。
例子
假设有两个 GPU,各自计算出部分结果 output_parallel_0
和 output_parallel_1
:
- 使用
gather_from_model_parallel_region
后,一个指定的 GPU(如 rank 0)上会得到output_parallel_0
和output_parallel_1
的合并结果,其他 GPU 上不会得到这个结果。
函数实现和使用场景的区别
reduce_from_model_parallel_region
def reduce_from_model_parallel_region(output_parallel: torch.Tensor) -> torch.Tensor:
"""Reduce across all the model parallel GPUs."""
if get_model_parallel_world_size() == 1:
return output_parallel
# Bypass the function if we are using only 1 GPU.
torch.distributed.all_reduce(output_parallel, group=get_model_parallel_group())
return output_parallel
实现:使用 torch.distributed.all_reduce
对所有 GPU 上的张量进行加和操作。
- 使用场景:需要所有 GPU 上都得到相同的加和结果。
gather_from_model_parallel_region
def gather_from_model_parallel_region(output_parallel: torch.Tensor) -> torch.Tensor:
"""Gather tensors and concatinate along the last dimension."""
world_size = get_model_parallel_world_size()
if world_size == 1:
return output_parallel
# Size and dimension.
numel = output_parallel.numel()
output_tensor = torch.empty(numel * world_size, dtype=output_parallel.dtype,
device=output_parallel.device)
# Do the all-gather.
torch.distributed._all_gather_base(output_tensor, output_parallel.contiguous(),
group=get_model_parallel_group())
# If we are using the original approach, then this could be something like:
# tensor_list = [torch.empty_like(output_parallel) for _ in range(world_size)]
# torch.distributed.all_gather(tensor_list, output_parallel, group=get_model_parallel_group())
# output_tensor = torch.cat(tensor_list, dim=-1)
return output_tensor
- 实现:使用
torch.distributed._all_gather_base
或torch.distributed.all_gather
将所有 GPU 上的张量收集到一个张量中。 - 使用场景:需要将所有 GPU 上的张量收集到一个指定的 GPU 上进行后续处理。
总结
reduce_from_model_parallel_region
:将多个 GPU 上的张量进行加和操作,所有 GPU 最终得到相同的结果。gather_from_model_parallel_region
:将多个 GPU 上的张量收集到一个指定的 GPU 上,其他 GPU 不会得到完整的结果。