1. Background
Recently, while reproducing the Partial FC model with mmcv, I noticed that the source code implements a hand-written forward and backward pass for one step, which puzzled me.
Source:
# Features all-gather
total_features = torch.zeros(features.size()[0] * cfg.world_size,
                             cfg.embedding_size,
                             device=local_rank)
dist.all_gather(list(total_features.chunk(cfg.world_size, dim=0)),
                features.data)
total_features.requires_grad = True
...
# compute the loss and call backward on it
...
if total_features.grad is not None:
    total_features.grad.detach_()
x_grad = torch.zeros_like(features)
# Feature gradient all-reduce
dist.reduce_scatter(
    x_grad, list(total_features.grad.chunk(cfg.world_size, dim=0)))
x_grad.mul_(cfg.world_size)
# Backward backbone
features.backward(x_grad)
optimizer.step()
2. What does all_gather actually do?
The official example for dist.all_gather:
>>> # All tensors below are of torch.int64 dtype.
>>> # We have 2 process groups, 2 ranks.
>>> tensor_list = [torch.zeros(2, dtype=torch.int64) for _ in range(2)]
>>> tensor_list
[tensor([0, 0]), tensor([0, 0])] # Rank 0 and 1
>>> tensor = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank
>>> tensor
tensor([1, 2]) # Rank 0
tensor([3, 4]) # Rank 1
>>> dist.all_gather(tensor_list, tensor)
>>> tensor_list
[tensor([1, 2]), tensor([3, 4])] # Rank 0
[tensor([1, 2]), tensor([3, 4])] # Rank 1
Notice that tensor_list, the output of dist.all_gather, is a list that has to be freshly allocated before the call. In the forward direction data does flow tensor –> tensor_list, but in the backward direction nothing flows: the tensors in tensor_list are newly created leaf nodes with no grad_fn, so gradients cannot propagate from tensor_list.grad back to tensor.grad.
To solve this, we have to wire up the backward pass across this step ourselves.
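To make the breakage observable, here is a minimal sketch I use for illustration (not from the PFC source): it spins up a single-process "cluster" (world_size = 1, gloo backend, file-based rendezvous, all of which are my own demo choices) and checks that the gathered tensors carry no autograd history. Like the PFC code's use of features.data, the input is detached before the collective.

```python
import tempfile
import torch
import torch.distributed as dist

def gathered_is_detached():
    # single-process process group, just so the collective can run
    if not dist.is_initialized():
        init_file = tempfile.NamedTemporaryFile(delete=False).name
        dist.init_process_group("gloo", init_method=f"file://{init_file}",
                                rank=0, world_size=1)
    x = torch.randn(2, 3, requires_grad=True)
    gathered = [torch.zeros(2, 3) for _ in range(dist.get_world_size())]
    dist.all_gather(gathered, x.detach())
    # the gathered tensors are freshly allocated leaves: no grad_fn,
    # so a backward pass through them can never reach x
    return gathered[0].requires_grad, gathered[0].grad_fn

gathered_is_detached()  # → (False, None)
```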
3. Solutions
(1) The first approach is what the PFC source does: split the backward pass into stages. (Below, total_tensor plays the role of total_features: it is the gathered tensor whose chunks were passed to all_gather as the output list.)
# 1. backward from the loss down to the gathered tensor
loss.backward()
...
# 2. at the all_gather boundary, stitch the gradients together
if total_tensor.grad is not None:
    total_tensor.grad.detach_()
x_grad = torch.zeros_like(tensor)
# scatter each rank's slice of the gradient back to that GPU
# Feature gradient all-reduce
dist.reduce_scatter(
    x_grad, list(total_tensor.grad.chunk(cfg.world_size, dim=0)))
x_grad.mul_(cfg.world_size)
# 3. backward through the rest of the model (the backbone)
tensor.backward(x_grad)
# update the weights
optimizer.step()
(2) Define a custom torch.autograd.Function class
For operations that autograd cannot differentiate automatically, PyTorch provides the torch.autograd.Function extension point so you can define the gradient yourself. The PyTorch docs give this example:
class Exp(Function):
    # do whatever fancy tricks you like in the forward
    @staticmethod
    def forward(ctx, i):
        result = i.exp()
        ctx.save_for_backward(result)
        return result

    # the tricks are too fancy for autograd, so we write the backward ourselves
    @staticmethod
    def backward(ctx, grad_output):
        result, = ctx.saved_tensors
        return grad_output * result

# Use it by calling the apply method:
output = Exp.apply(input)
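As a quick sanity check of the Exp example (re-declared here so the snippet runs on its own), we can compare the custom backward against the analytic derivative, d/dx sum(exp(x)) = exp(x):

```python
import torch
from torch.autograd import Function

class Exp(Function):
    @staticmethod
    def forward(ctx, i):
        result = i.exp()
        ctx.save_for_backward(result)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        result, = ctx.saved_tensors
        return grad_output * result

def check_exp_grad():
    x = torch.randn(4, dtype=torch.double, requires_grad=True)
    y = Exp.apply(x)
    y.sum().backward()
    # the custom backward should reproduce the analytic gradient exp(x)
    return torch.allclose(x.grad, x.detach().exp())

check_exp_grad()  # → True
```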
So the earlier problem can be solved like this:
class BwFunction(Function):
    @staticmethod
    def forward(ctx, x):
        world_size = dist.get_world_size()
        total_features = torch.zeros(x.size(0) * world_size, x.size(1),
                                     device=x.device)
        dist.all_gather(list(total_features.chunk(world_size, dim=0)), x.data)
        return total_features

    @staticmethod
    def backward(ctx, grad_output):
        world_size = dist.get_world_size()
        grad_x = None
        if grad_output is not None:
            # x was not saved in forward, so rebuild the per-rank shape
            # from grad_output instead of calling torch.zeros_like(x)
            x_grad = torch.zeros(grad_output.size(0) // world_size,
                                 grad_output.size(1),
                                 device=grad_output.device)
            # Feature gradient all-reduce
            dist.reduce_scatter(
                x_grad, list(grad_output.contiguous().chunk(world_size, dim=0)))
            x_grad.mul_(world_size)  # same scaling as in the PFC source above
            grad_x = x_grad
        return grad_x
After dragging this out for so long, it's finally written up. Confetti!
References:
- https://github.com/deepinsight/insightface/tree/master/recognition/partial_fc
- https://discuss.pytorch.org/t/will-dist-all-gather-break-the-auto-gradient-graph/47350
- https://pytorch.org/docs/stable/autograd.html?highlight=autograd#module-torch.autograd
- https://blog.csdn.net/Hungryof/article/details/78346304