# Imports as in the Colossal-AI codebase these helpers come from (paths assumed).
import torch
import torch.distributed as dist
from torch import Tensor
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc

def _reduce(tensor: Tensor) -> Tensor:
    # No-op when the tensor-parallel group has a single rank.
    if gpc.get_world_size(ParallelMode.TENSOR) == 1:
        return tensor
    # Sum the tensor in place across all ranks of the tensor-parallel group.
    dist.all_reduce(tensor,
                    op=dist.ReduceOp.SUM,
                    group=gpc.get_group(ParallelMode.TENSOR),
                    async_op=False)
    return tensor
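To see why a SUM all-reduce is the right combiner, the following single-process sketch simulates two tensor-parallel ranks of a row-parallel linear layer: each "rank" multiplies its slice of the input features by the matching rows of the weight, and the sum of the partial products equals the full matmul. The shapes and the two-way split are invented for illustration.

import torch

torch.manual_seed(0)
x = torch.randn(4, 8)   # full input
w = torch.randn(8, 6)   # full weight

# Simulate world_size=2: rank i holds input features 4i..4i+3
# and the matching rows of the weight.
partials = [x[:, i * 4:(i + 1) * 4] @ w[i * 4:(i + 1) * 4, :] for i in range(2)]

# dist.all_reduce(..., op=SUM) would leave this sum on every rank.
assert torch.allclose(sum(partials), x @ w, atol=1e-5)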
- Split a tensor along a given dimension, keeping only the shard owned by the current rank:
from colossalai.nn.layer.utils import divide  # assumed import path; divide() only allows exact division

def _split(tensor: Tensor, dim: int = -1) -> Tensor:
    # No-op when the tensor-parallel group has a single rank.
    if gpc.get_world_size(ParallelMode.TENSOR) == 1:
        return tensor
    # Evenly partition `dim` across the tensor-parallel ranks and keep
    # only the chunk that belongs to this rank.
    split_size = divide(tensor.shape[dim], gpc.get_world_size(ParallelMode.TENSOR))
    tensor_list = torch.split(tensor, split_size, dim=dim)
    output = tensor_list[gpc.get_local_rank(ParallelMode.TENSOR)].contiguous()
    return output
- tensor_list: the tuple of chunks produced by `torch.split`.
- gpc.get_local_rank(ParallelMode.TENSOR): the rank of the current process within its tensor-parallel group.
- tensor_list[gpc.get_local_rank(ParallelMode.TENSOR)]: picks the chunk owned by this rank.
- .contiguous(): `torch.split` returns views into the original storage, so this copies the selected chunk into contiguous memory, which downstream kernels expect. A runnable single-process sketch of this logic follows.
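The `gpc` calls make `_split` hard to run outside a launched distributed job, so here is a minimal single-process sketch of the same bookkeeping, with a hypothetical `world_size` and `rank` standing in for `gpc.get_world_size` and `gpc.get_local_rank`:

import torch

world_size, rank = 4, 2            # stand-ins for the gpc calls
tensor = torch.arange(32).reshape(2, 16)

split_size = tensor.shape[-1] // world_size    # divide() would also assert divisibility
shard = torch.split(tensor, split_size, dim=-1)[rank].contiguous()

print(shard)   # columns 8..11 of the original tensor, stored contiguously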
- Gather the tensor shards from every process in the tensor-parallel group and merge them back into one tensor:
def _gather(tensor: Tensor, dim: int = -1) -> Tensor:
    # No-op when the tensor-parallel group has a single rank.
    if gpc.get_world_size(ParallelMode.TENSOR) == 1:
        return tensor
    if dim == 1 and list(tensor.shape)[0] == 1:
        # Fast path: because dim 0 has size 1, the chunks of `output` along
        # dim 1 are contiguous views into one buffer, so all_gather can write
        # into `output` directly and no extra torch.cat copy is needed.
        output_shape = list(tensor.shape)
        output_shape[1] *= gpc.get_world_size(ParallelMode.TENSOR)
        output = torch.empty(output_shape, dtype=tensor.dtype, device=tensor.device)
        tensor_list = output.chunk(gpc.get_world_size(ParallelMode.TENSOR), dim=1)
        dist.all_gather(list(tensor_list),
                        tensor,
                        group=gpc.get_group(ParallelMode.TENSOR),
                        async_op=False)
    else:
        # General path: gather into per-rank buffers, then concatenate.
        tensor_list = [
            torch.empty_like(tensor) for _ in range(gpc.get_world_size(ParallelMode.TENSOR))
        ]
        dist.all_gather(tensor_list,
                        tensor,
                        group=gpc.get_group(ParallelMode.TENSOR),
                        async_op=False)
        output = torch.cat(tensor_list, dim=dim)
    return output
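`_gather` is the inverse of `_split`: concatenating the per-rank shards in rank order reproduces the unsharded tensor. A single-process sketch of that round trip, where `chunk` stands in for what each rank holds after `_split` and `cat` for what `all_gather` plus `torch.cat` returns:

import torch

world_size, dim = 4, -1
tensor = torch.randn(2, 16)

shards = [t.contiguous() for t in tensor.chunk(world_size, dim=dim)]  # per-rank shards
restored = torch.cat(shards, dim=dim)                                 # gathered result

assert torch.equal(restored, tensor)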