Pytorch中torch.nn.DataParallel负载均衡问题

本文探讨了在Pytorch中使用DataParallel进行多卡训练时遇到的负载不均衡问题,尤其是在处理GAN训练时,由于某些特定操作与DDP不兼容。通过详细分析DataParallel的工作原理和内存占用,提出了一种改进方案,旨在提高GPU利用率。文章提供了负载均衡的代码实现和使用示例。
摘要由CSDN通过智能技术生成

1. 问题概述

现在Pytorc下进行多卡训练主流的是采用torch.nn.parallel.DistributedDataParallel()(DDP)方法,但是在一些特殊的情况下这样的方法就使用不了了,特别是在进行与GAN相关的训练的时候,假如使用的损失函数是 WGAN-GP(LP),DRAGAN,那么其中会用到基于梯度的惩罚,其使用到的函数为torch.autograd.grad(),但是很不幸的是在实验的过程中该函数使用DDP会报错:

File "/home/work/anaconda3/envs/xxxxx_py/lib/python3.6/site-packages/torch/autograd/__init__.py", line 93, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: derivative for batch_norm_backward_elemt is not implemented

那么需要并行(单机多卡)计算那么就只能使用torch.nn.DataParallel()了,但是也带来另外一个问题那就是负载极其不均衡,使用这个并行计算方法会在主GPU上占据较多的现存,而其它的GPU显存则只占用了一部分,这样就使得无法再继续增大batchsize了,下图就是这种方式进行计算,整个数据流的路线:在这里插入图片描述
可以在上图中看到输入数据计算和损失计算过程中都会存在数据汇总的情况,这就难免使得主卡的显存爆掉,为了解决这样的问题一个思想就是其网络前向、计算损失的过程都采用并行的方式进行,其流程如下:
在这里插入图片描述

这样就可以解决显卡利用率不高的问题,下面给出一些可以参考的负载均衡代码:

  1. PyTorch-Encoding
  2. Balanced-DataParallel
  3. parallel.py

2. 代码实现

基于上述内容中的工作,这里将这个的并行过程汇集到一个文件里面,这样可以很方便将其当做是模块使用。

##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
## Created by: Hang Zhang
## ECE Department, Rutgers University
## Email: zhang.hang@rutgers.edu
## Copyright (c) 2017
##
## This source code is licensed under the MIT-style license found in the
## LICENSE file in the root directory of this source tree
##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

"""Encoding Data Parallel"""
import threading
import functools
import torch
from torch.autograd import Variable, Function
import torch.cuda.comm as comm
from torch.nn.parallel.data_parallel import DataParallel
from torch.nn.parallel.parallel_apply import get_a_var
from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast

torch_ver = torch.__version__[:3]

__all__ = ['allreduce', 'DataParallelModel', 'DataParallelCriterion',
           'patch_replication_callback']

def allreduce(*inputs):
    """Cross GPU all reduce autograd operation for calculate mean and
    variance in SyncBN.
    """
    return AllReduce.apply(*inputs)

class AllReduce(Function):
    @staticmethod
    def forward(ctx, num_inputs, *inputs):
        ctx.num_inputs = num_inputs
        ctx.target_gpus = [inputs[i].get_device() for i in range(0, len(inputs), num_inputs)]
        inputs = [inputs[i:i + num_inputs]
                 for i in range(0, len(inputs), num_inputs)]
        # sort before reduce sum
        inputs = sorted(inputs, key=lambda i: i[0].get_device())
        results = comm.reduce_add_coalesced(inputs, ctx.target_gpus[0])
        outputs = comm.broadcast_coalesced(results, ctx
PyTorch是一个基于Python的科学计算库,主要针对深度学习任务。在PyTorchtorch.nn是一个用于构建神经网络模型的模块。 torch.nn模块提供了一系列神经网络层和函数,方便用户构建自定义的神经网络。用户可以通过继承torch.nn.Module类来定义自己的神经网络模型。torch.nn模块常用的类包括各种层(例如全连接层、卷积层、池化层和循环层等)、非线性激活函数和损失函数等。 在使用torch.nn模块构建神经网络时,用户需要实现模型的前向传播函数forward()。该函数定义了输入数据在神经网络的流动方式,即通过层和函数的组合计算输出。在forward()函数,用户可以使用已定义的层和函数进行计算,也可以实现自定义的操作。 torch.nn模块的另一个重要概念是参数(parameter)。参数是模型需要学习的变量,例如网络层的权重和偏置项。用户可以通过在模型定义torch.nn.Parameter对象来创建参数,并在forward()函数进行使用。 除了torch.nn模块外,PyTorch还提供了其他的工具和模块来辅助神经网络的训练和优化过程。例如torch.optim模块包含了各种优化算法,如随机梯度下降(SGD)、Adam等,用于更新模型的参数。torch.utils.data模块提供了数据处理和加载的工具,方便用户使用自己的数据训练模型。 总之,torch.nn模块是PyTorch用于构建神经网络模型的重要组成部分。通过使用torch.nn的各种类和函数,用户可以方便地创建自己想要的神经网络结构,并利用PyTorch强大的计算能力和优化算法来训练和优化模型。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值