1. 问题概述
现在Pytorc下进行多卡训练主流的是采用torch.nn.parallel.DistributedDataParallel()
(DDP)方法,但是在一些特殊的情况下这样的方法就使用不了了,特别是在进行与GAN相关的训练的时候,假如使用的损失函数是 WGAN-GP(LP),DRAGAN
,那么其中会用到基于梯度的惩罚,其使用到的函数为torch.autograd.grad()
,但是很不幸的是在实验的过程中该函数使用DDP会报错:
File "/home/work/anaconda3/envs/xxxxx_py/lib/python3.6/site-packages/torch/autograd/__init__.py", line 93, in backward
allow_unreachable=True) # allow_unreachable flag
RuntimeError: derivative for batch_norm_backward_elemt is not implemented
那么需要并行(单机多卡)计算那么就只能使用torch.nn.DataParallel()
了,但是也带来另外一个问题那就是负载极其不均衡,使用这个并行计算方法会在主GPU上占据较多的现存,而其它的GPU显存则只占用了一部分,这样就使得无法再继续增大batchsize了,下图就是这种方式进行计算,整个数据流的路线:
可以在上图中看到输入数据计算和损失计算过程中都会存在数据汇总的情况,这就难免使得主卡的显存爆掉,为了解决这样的问题一个思想就是其网络前向、计算损失的过程都采用并行的方式进行,其流程如下:
这样就可以解决显卡利用率不高的问题,下面给出一些可以参考的负载均衡代码:
2. 代码实现
基于上述内容中的工作,这里将这个的并行过程汇集到一个文件里面,这样可以很方便将其当做是模块使用。
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: Hang Zhang
## ECE Department, Rutgers University
## Email: zhang.hang@rutgers.edu
## Copyright (c) 2017
##
## This source code is licensed under the MIT-style license found in the
## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
"""Encoding Data Parallel"""
import threading
import functools
import torch
from torch.autograd import Variable, Function
import torch.cuda.comm as comm
from torch.nn.parallel.data_parallel import DataParallel
from torch.nn.parallel.parallel_apply import get_a_var
from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast
torch_ver = torch.__version__[:3]
__all__ = ['allreduce', 'DataParallelModel', 'DataParallelCriterion',
'patch_replication_callback']
def allreduce(*inputs):
"""Cross GPU all reduce autograd operation for calculate mean and
variance in SyncBN.
"""
return AllReduce.apply(*inputs)
class AllReduce(Function):
@staticmethod
def forward(ctx, num_inputs, *inputs):
ctx.num_inputs = num_inputs
ctx.target_gpus = [inputs[i].get_device() for i in range(0, len(inputs), num_inputs)]
inputs = [inputs[i:i + num_inputs]
for i in range(0, len(inputs), num_inputs)]
# sort before reduce sum
inputs = sorted(inputs, key=lambda i: i[0].get_device())
results = comm.reduce_add_coalesced(inputs, ctx.target_gpus[0])
outputs = comm.broadcast_coalesced(results, ctx