1. Official documentation
Official docs: torch.distributed.all_gather()
all_gather(tensor_list, tensor, group=None, async_op=False):
tensor_list is the output list with one entry per rank; tensor is the tensor contributed by the current process. Each element of tensor_list must have the same shape as the tensor being gathered from the corresponding rank.
Source code from the official repository:
def all_gather(tensor_list,
               tensor,
               group=None,
               async_op=False):
    """
    Gathers tensors from the whole group in a list.

    Complex tensors are supported.

    Args:
        tensor_list (list[Tensor]): Output list. It should contain
            correctly-sized tensors to be used for output of the collective.
        tensor (Tensor): Tensor to be broadcast from current process.
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.
        async_op (bool, optional): Whether this op should be an async op

    Returns:
        Async work handle, if async_op is set to True.
        None, if not async_op or if not part of the group

    Examples:
        >>> # All tensors below are of torch.int64 dtype.
        >>> # We have 2 process groups, 2 ranks.
        >>> tensor_list = [torch.zeros(2, dtype=torch.int64) for _ in range(2)]
        >>> tensor_list
        [tensor([0, 0]), tensor([0, 0])] # Rank 0 and 1
        >>> tensor = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank
        >>> tensor
        tensor([1, 2]) # Rank 0
        tensor([3, 4]) # Rank 1
        >>> dist.all_gather(tensor_list, tensor)
        >>> tensor_list
        [tensor([1, 2]), tensor([3, 4])] # Rank 0
        [tensor([1, 2]), tensor([3, 4])] # Rank 1

        >>> # All tensors below are of torch.cfloat dtype.
        >>> # We have 2 process groups, 2 ranks.
        >>> tensor_list = [torch.zeros(2, dtype=torch.cfloat) for _ in range(2)]
        >>> tensor_list
        [tensor([0.+0.j, 0.+0.j]), tensor([0.+0.j, 0.+0.j])] # Rank 0 and 1
        >>> tensor = torch.tensor([1+1j, 2+2j], dtype=torch.cfloat) + 2 * rank * (1+1j)
        >>> tensor
        tensor([1.+1.j, 2.+2.j]) # Rank 0
        tensor([3.+3.j, 4.+4.j]) # Rank 1
        >>> dist.all_gather(tensor_list, tensor)
        >>> tensor_list
        [tensor([1.+1.j, 2.+2.j]), tensor([3.+3.j, 4.+4.j])] # Rank 0
        [tensor([1.+1.j, 2.+2.j]), tensor([3.+3.j, 4.+4.j])] # Rank 1
    """
    _check_tensor_list(tensor_list, "tensor_list")
    _check_single_tensor(tensor, "tensor")
    if _rank_not_in_group(group):
        return

    tensor_list = [t if not t.is_complex() else torch.view_as_real(t) for t in tensor_list]
    tensor = tensor if not tensor.is_complex() else torch.view_as_real(tensor)

    if group is None:
        default_pg = _get_default_group()
        work = default_pg.allgather([tensor_list], [tensor])
    else:
        work = group.allgather([tensor_list], [tensor])

    if async_op:
        return work
    else:
        work.wait()
Official example:
# All tensors below are of torch.int64 dtype.
# We have 2 process groups, 2 ranks.
tensor_list = [torch.zeros(2, dtype=torch.int64) for _ in range(2)]
tensor_list
tensor = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank
tensor
dist.all_gather(tensor_list, tensor)
tensor_list
# All tensors below are of torch.cfloat dtype.
# We have 2 process groups, 2 ranks.
tensor_list = [torch.zeros(2, dtype=torch.cfloat) for _ in range(2)]
tensor_list
tensor = torch.tensor([1+1j, 2+2j], dtype=torch.cfloat) + 2 * rank * (1+1j)
tensor
dist.all_gather(tensor_list, tensor)
tensor_list
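The example above assumes it already runs inside an initialized process group, so rank and dist are defined. A self-contained way to try it on a single machine is sketched below; the gloo backend, the port number, and the mp.spawn launcher are choices made for this sketch, not part of the official example:

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def run(rank, world_size):
    # gloo works on CPU tensors, so no GPU is needed for this check.
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group('gloo', rank=rank, world_size=world_size)

    tensor_list = [torch.zeros(2, dtype=torch.int64) for _ in range(world_size)]
    tensor = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank
    dist.all_gather(tensor_list, tensor)
    print(f'rank {rank}: {tensor_list}')  # every rank prints [tensor([1, 2]), tensor([3, 4])]

    dist.destroy_process_group()

if __name__ == '__main__':
    world_size = 2
    mp.spawn(run, args=(world_size,), nprocs=world_size)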
2. all_gather() without gradient propagation, for model test/eval mode
torch.distributed.all_gather itself does not back-propagate gradients, as the following code shows:
import os
import torch

batch_size = 16
rank = int(os.environ.get('OMPI_COMM_WORLD_RANK', '0'))
world_size = int(os.environ.get('OMPI_COMM_WORLD_SIZE', '1'))
bs_each = batch_size // world_size
device_id = int(os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK', '0'))
torch.cuda.set_device(device_id)
torch.distributed.init_process_group(
    backend='nccl',
    init_method='tcp://localhost:12345',
    rank=rank,
    world_size=world_size,
)

from torch import nn
model = nn.Linear(1, 1, bias=False)
model.weight.data[:] = 1.
model = model.cuda()

x = torch.ones((bs_each, 1), requires_grad=True).cuda()
y = model(x)
ys = [torch.zeros_like(y) for _ in range(world_size)]

torch.distributed.all_gather(ys, y)
print(y.grad_fn)
# <MmBackward object at 0x7f2073fc3ba8>
for sub_y in ys:
    print(sub_y.grad_fn)
# None
Running this code, it first prints the real gradient function y.grad_fn of the local output, which never went through all_gather. After dist.all_gather is called, the tensors in ys have no grad_fn, which can be understood as: no gradient is back-propagated through them.
In practice, it is recommended to wrap the all_gather call in torch.no_grad() to state explicitly that no gradient will be back-propagated.
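A minimal sketch of that pattern, reusing the names y, ys and world_size from the code above:

with torch.no_grad():
    ys = [torch.zeros_like(y) for _ in range(world_size)]
    torch.distributed.all_gather(ys, y)
# ys has no grad_fn either way; the no_grad block just makes the intent explicit.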
Template code:
logits = torch.cat(logits_list, dim=0)
targets = torch.cat(targets_list, dim=0)
# For distributed parallel, collect all data and then run metrics.
if torch.distributed.is_initialized():
    # ngpus_per_node is assumed to equal the world size of the process group.
    logits_gather_list = [torch.zeros_like(logits) for _ in range(ngpus_per_node)]
    torch.distributed.all_gather(logits_gather_list, logits)
    logits = torch.cat(logits_gather_list, dim=0)
    targets_gather_list = [torch.zeros_like(targets) for _ in range(ngpus_per_node)]
    torch.distributed.all_gather(targets_gather_list, targets)
    targets = torch.cat(targets_gather_list, dim=0)
accuracy, recall, precision, auc = classification_metrics(logits, targets)
3. all_gather() with gradient propagation, for model train mode
with torch.no_grad():
    all_x = [torch.zeros_like(x) for _ in range(world_size)]
    torch.distributed.all_gather(all_x, x)
all_x[rank] = x
all_x now contains the x produced by every GPU. None of these tensors has a grad_fn except the one coming from the current GPU, because of the assignment all_x[rank] = x. The loss can then be computed from all_x.
In other words, the original tensor on the current GPU is written back into all_x at its own rank index, so that all_x[rank] keeps its gradient history and every GPU can still back-propagate through its own part of the computation.
Template code:
import torch
import torch.distributed as dist
# Dummy code representing the forward pass for some batch of text on one replica.
embeddings = model(batch)
# Gather the embeddings from every replica.
embeddings_list = [torch.ones_like(embeddings) for _ in range(dist.get_world_size())]
dist.all_gather(embeddings_list, embeddings)
# Gathered tensors have no gradient, so we overwrite the gathered tensor for the current replica
# with the embeddings produced on this replica, which do have gradients.
embeddings_list[dist.get_rank()] = embeddings
# Finally, concatenate the list of embeddings before computing a loss.
embeddings = torch.cat(embeddings_list)
# Generating the labels is not shown here; it is task-dependent.
loss = some_contrastive_loss(embeddings, labels)
Reference:
https://github.com/KevinMusgrave/pytorch-metric-learning/issues/10
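To verify that only the local entry carries gradient history, you can inspect grad_fn right after the overwrite; a small sketch reusing the names from the template above:

# Only the entry at dist.get_rank() prints a grad_fn; the gathered copies print None.
for i, e in enumerate(embeddings_list):
    print(i, e.grad_fn)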
The following three gradient-aware all_gather implementations (code in the style of SimCLR) achieve the same result:
1.
import torch
import torch.distributed as dist

class GatherLayer(torch.autograd.Function):
    '''Gather tensors from all processes, supporting backward propagation.'''

    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        output = [torch.zeros_like(input)
                  for _ in range(dist.get_world_size())]
        dist.all_gather(output, input)
        return tuple(output)

    @staticmethod
    def backward(ctx, *grads):
        input, = ctx.saved_tensors
        grad_out = torch.zeros_like(input)
        # Keep only the gradient slice that corresponds to this rank's own input.
        grad_out[:] = grads[dist.get_rank()]
        return grad_out
Usage:
allgather = GatherLayer.apply
features_gather = allgather(features)  # gather the features from all GPUs
Reference:
https://i.steer.space/blog/2021/01/pytorch-dist-nccl-backend-allgather-stuck
2.
class SyncFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, tensor):
        ctx.batch_size = tensor.shape[0]
        gathered_tensor = [torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())]
        torch.distributed.all_gather(gathered_tensor, tensor)
        gathered_tensor = torch.cat(gathered_tensor, 0)
        return gathered_tensor

    @staticmethod
    def backward(ctx, grad_output):
        grad_input = grad_output.clone()
        # Sum the gradient of the concatenated output over all ranks, then return
        # only the slice that corresponds to this rank's input batch.
        torch.distributed.all_reduce(grad_input, op=torch.distributed.ReduceOp.SUM, async_op=False)
        idx_from = torch.distributed.get_rank() * ctx.batch_size
        idx_to = (torch.distributed.get_rank() + 1) * ctx.batch_size
        return grad_input[idx_from:idx_to]
Usage:
allgather = SyncFunction.apply
features_gather = allgather(features)  # gather the features from all GPUs
3.
import torch.distributed as dist

class AllGather(torch.autograd.Function):
    """An autograd function that performs allgather on a tensor."""

    @staticmethod
    def forward(ctx, tensor):
        output = [torch.empty_like(tensor) for _ in range(dist.get_world_size())]
        torch.distributed.all_gather(output, tensor)
        ctx.rank = dist.get_rank()
        ctx.batch_size = tensor.shape[0]
        return torch.cat(output, dim=0)

    @staticmethod
    def backward(ctx, grad_output):
        # Return only the gradient slice that corresponds to this rank's input
        # (forward takes a single tensor argument, so a single gradient is returned).
        return grad_output[ctx.batch_size * ctx.rank : ctx.batch_size * (ctx.rank + 1)]
Usage:
allgather = AllGather.apply
features_gather = allgather(features)  # gather the features from all GPUs
Reference: https://github.com/ArrowLuo/CLIP4Clip
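A small self-contained sanity check for these variants; this is a sketch under assumptions (two CPU processes, the gloo backend, the SyncFunction class from variant 2 defined in the same file, and an arbitrary sum used as the loss only to trigger backward):

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def run(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29501'
    dist.init_process_group('gloo', rank=rank, world_size=world_size)

    x = torch.ones(4, 1, requires_grad=True)       # per-rank input batch
    gathered = SyncFunction.apply(x * (rank + 1))   # differentiable all_gather
    loss = gathered.sum()
    loss.backward()
    # x.grad is populated on every rank, unlike with a plain dist.all_gather.
    print(f'rank {rank}: x.grad = {x.grad.flatten()}')

    dist.destroy_process_group()

if __name__ == '__main__':
    mp.spawn(run, args=(2,), nprocs=2)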
4. Related links
- Official documentation: torch.distributed.all_gather()
- PyTorch - gradient back-propagation with torch.distributed.all_gather
- Hands-on PyTorch multi-process distributed training
- Debugging a PyTorch distributed.all_gather hang under the NCCL backend
- Basic concepts and issues of PyTorch distributed DDP
- PyTorch distributed training tutorial: scatter, gather & isend, irecv & all_reduce & DDP
- Illustrated communication patterns of DistributedDataParallel (DDP): gather, all_gather, all_reduce, reduce, scatter