使用torch.autograd.function解决dist.all_gather不能反向传播问题

1. 问题来源

最近在用mmcv复现Partial FC模型,看到源码中,有单独写的前向反向传播,甚是疑惑~
源码:

 # Features all-gather
total_features = torch.zeros(features.size()[0] * cfg.world_size,
                             cfg.embedding_size,
                             device=local_rank)
dist.all_gather(list(total_features.chunk(cfg.world_size, dim=0)),    
                features.data)
total_features.requires_grad = True

...
#计算 loss 以及 backward
...

if total_features.grad is not None:
    total_features.grad.detach_()
x_grad = torch.zeros_like(features)

# Feature gradient all-reduce
dist.reduce_scatter(
    x_grad, list(total_features.grad.chunk(cfg.world_size, dim=0)))
x_grad.mul_(cfg.world_size)
# Backward backbone
features.backward(x_grad)
optimizer.step()

2. all_gather做了啥

dist.all_gather官方样例:

>>> # All tensors below are of torch.int64 dtype.
>>> # We have 2 process groups, 2 ranks.
>>> tensor_list = [torch.zero(2, dtype=torch.int64) for _ in range(2)]
>>> tensor_list
[tensor([0, 0]), tensor([0, 0])] # Rank 0 and 1
>>> tensor = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank
>>> tensor
tensor([1, 2]) # Rank 0
tensor([3, 4]) # Rank 1
>>> dist.all_gather(tensor_list, tensor)
>>> tensor_list
[tensor([1, 2]), tensor([3, 4])] # Rank 0
[tensor([1, 2]), tensor([3, 4])] # Rank 1

发现 dist.all_gather的输出值 tensor_list,是在all_gather计算之前 需要新初始化的一个list,虽然前向传播时 tensor–>tensor_list, 但是反向传播时,由于 tensor_list 是新初始化的叶子结点,所以并不能实现 tensor_list.grad–>tensor.grad 。

所以解决这个问题,需要将该部分的反向传播搭建起来。

3. 解决方案

(1)第一种方法就是像PFC源码里写的那样,反向传播时分段处理:

# 1. 直接loss反向传播
loss.backward()
...
# 2. 在all_gather部分,对梯度进行衔接
if tensor_list.grad is not None:
    tensor_list.grad.detach_()
x_grad = torch.zeros_like(tensor)

# 将梯度对应分配到各个GPU上
# Feature gradient all-reduce
dist.reduce_scatter(
    x_grad, list(tensor_list.grad.chunk(cfg.world_size, dim=0)))
x_grad.mul_(cfg.world_size)

# 3. 剩余部分反向传播 Backward backbone
tensor.backward(x_grad)
# 梯度更新
optimizer.step()

(2)自定义torch.autograd.function类
对于这些不可自动求导的操作,pytorch给出了扩展 torch.autograd.function 来实现自定义求导方式,pytorch文档里也给出了使用样例:

class Exp(Function):
	# 定义一些前向骚操作
    @staticmethod
    def forward(ctx, i):
        result = i.exp()
        ctx.save_for_backward(result)
        return result
    # 前向操作太骚,只好自己写反传啦~。~
	@staticmethod
    def backward(ctx, grad_output):
        result, = ctx.saved_tensors
        return grad_output * result

# 应用时就可以这么搞拉
#Use it by calling the apply method:
output = Exp.apply(input)

所以前面那个问题就可以解决啦:

class BwFunction(Function):

    @staticmethod
    def forward(ctx, x):
        world_size = dist.get_world_size()
        total_features = torch.zeros(x.size()[0]*world_size, x.size()[1], device=x.device)       
        dist.all_gather(list(total_features.chunk(world_size, dim=0)), x.data)  
        total_features.requires_grad = True
        return total_features

    @staticmethod
    def backward(ctx, grad_output): 
        
        world_size = dist.get_world_size()
        grad_x = None
        
        if grad_output is not None:
            grad_output.detach_()
            x_grad = torch.zeros_like(x)

            # Feature gradient all-reduce
            dist.reduce_scatter(
                x_grad, list(grad_output.chunk(world_size, dim=0)))
            x_grad.div_(world_size)

        grad_x = x_grad

        return grad_x

拖了这么久,好不容易写完,撒花!

参考:

  1. https://github.com/deepinsight/insightface/tree/master/recognition/partial_fc
  2. https://discuss.pytorch.org/t/will-dist-all-gather-break-the-auto-gradient-graph/47350
  3. https://pytorch.org/docs/stable/autograd.html?highlight=autograd#module-torch.autograd
  4. https://blog.csdn.net/Hungryof/article/details/78346304
  • 3
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值