Pytorch多GPU和Sync BatchNorm代码

由于复现spade的过程中遇到了一些GPU的问题,所以决定好好理解一下DPL

代码问题

终端显示GPU显示

终端显示暂行不动,并没有报错。GPU没有加载进程,同时CPU也没有动。
考虑可能是DPL的问题。由于代码中使用了Sync BatchNorm,考虑到可能是DPL的问题。

nn.DataParallel

在forward阶段,当前GPU上的module会被复制到其他GPU上,输入数据则会被切分,分别传到不同的GPU上进行计算;在backward阶段,每个GPU上的梯度会被求和并传回当前GPU上,并更新参数。也就是复制module -> forward -> 计算loss -> backward -> 汇总gradients -> 更新参数 -> 复制module -> …的不断重复执行,示意图如下:
DPL示意图因为数据会被均分到不同的GPU上,所以要求batch_size大于GPU的数量。下面对DataParallel的forward函数做一个简单的解释:

class DataParallel(Module):
    def __init__(self, module, device_ids=None, output_device=None, dim=0):
        super(DataParallel, self).__init__()
        if not torch.cuda.is_available():
            self.module = module
            self.device_ids = []
            return
        if device_ids is None:
            device_ids = list(range(torch.cuda.device_count()))
        if output_device is None:
            output_device = device_ids[0]
        self.dim = dim
        self.module = module   # 待并行计算的模型
        self.device_ids = list(map(lambda x: _get_device_index(x, True),device_ids))
        self.output_device = _get_device_index(output_device, True)
        self.src_device_obj = torch.device("cuda:{}".format(self.device_ids[0]))
        _check_balance(self.device_ids)
        if len(self.device_ids) == 1:
            self.module.cuda(device_ids[0])

    def forward(self, *inputs, **kwargs):
        if not self.device_ids:
            return self.module(*inputs, **kwargs)
        for t in chain(self.module.parameters(), self.module.buffers()):
            if t.device != self.src_device_obj:
                raise RuntimeError("module must have its parameters and buffers "
                                   "on device {} (device_ids[0]) but found one of "
                                   "them on device: {}".format(self.src_device_obj, t.device))

        inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)  
        if len(self.device_ids) == 1:
            return self.module(*inputs[0], **kwargs[0])
        replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
        outputs = self.parallel_apply(replicas, inputs, kwargs)
        return self.gather(outputs, self.output_device)

    def scatter(self, inputs, kwargs, device_ids):
        '''scatter_kwargs内部调用名为scatter的函数,
        作用是将输入数据及参数均分到每张卡上,以及复制其他类型对象的引用'''
        return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)

    def replicate(self, module, device_ids):
        '''replicate对输入模型的parameters、buffers、modules都一一进行copy,并返回copy的list,
        因为modules最终是以类似链表的形式存储的,所以list中只包含第一个module'''
        return replicate(module, device_ids)
        
    def parallel_apply(self, replicas, inputs, kwargs):
        '''内部调用python的Thread将分割好的input分配到不同的GPU上计算,并返回result dict
        这里会调用SyncBN的前向反馈方法'''
        return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])

    def gather(self, outputs, output_device):
        '''将每张卡的计算结果统一汇聚到主卡,从不同GPU上取回结果'''
        return gather(outputs, output_device, dim=self.dim)

原生的该方法只是将模型在每张卡上复制一份,并且没有建立起联系,而我们的 SyncBN 是需要进行同步的,因此需要重载该方法,让各张卡上的SyncBN 通过某种数据结构和同步机制建立起联系。

重载nn.DataParallel.replicate方法

可以设计一个继承nn.DataParallel的子类DataParallelWithCallBack,重载了replicate方法,子类的该方法先是调用父类的replicate方法,然后调用一个自定义的回调函数(这也是之所以命名为DataParallelWithCallBack的原因),该回调函数用于将各卡对应的 SyncBN 层关联起来,使得它们可以通过某种数据结构进行通信

class DataParallelWithCallback(DataParallel):
    def replicate(self, module, device_ids):
        modules = super(DataParallelWithCallback, self).replicate(module, device_ids)
        execute_replication_callbacks(modules)#回调函数
        return modules

子类重载的replicate方法

def execute_replication_callbacks(modules):
    master_copy = modules[0]#主卡上的模型
    nr_modules = len(list(master_copy.modules()))
    ctxs = [CallbackContext() for _ in range(nr_modules)]
    #上下文数据结构,用于将各卡上对应的SynBN层关联起来
    for i, module in enumerate(modules):
        for j, m in enumerate(module.modules()):
            if hasattr(m, '__data_parallel_replicate__'):#判断SynBN层的条件
                m.__data_parallel_replicate__(ctxs[j], i)#其余各卡向主卡进行注册关联

自定义的回调函数,将各卡对应的Syn-BN层进行关联,其中CallbackContext是一个自定义类,其中没有定义实质性的东西,作为一个上下文数据结构,实例化这个类的对象主要用于将各个卡上对应的Syn-BN层进行关联;__data_parallel_replicate__是在Syn-BN中定义的方法,在该方法中其余子卡上的Syn-BN层会向主卡进行注册,使得主卡能够通过某种数据结构和各卡进行通信。

SynBN的同步注册机制

由上可知,我们需要在 SynBN 中实现一个用于同步的注册方法,SynBN 中还需要设置一个用于管理同步的对象(下图中的 _sync_master),这个对象有一个注册方法,可将子卡注册到其主卡。

在 SyncBN 的方法中,若是主卡,则将上下文管理器的 sync_master 属性设置为这个管理同步的对象(_sync_master);否则,则调用上下文对象的同步管理对象的注册方法,将该卡向其主卡进行注册。

    def __data_parallel_replicate__(self, ctx, copy_id):
        self._is_parallel = True
        self._parallel_id = copy_id
        # parallel_id == 0 means master device.
        if self._parallel_id == 0:
            ctx.sync_master = self._sync_master
        else:
            self._slave_pipe = ctx.sync_master.register_slave(copy_id)#若非主卡则进行注册

SynBN的同步注册机制

    def register_slave(self, identifier):
        if self._activated:
            assert self._queue.empty(), 'Queue is not clean before next initialization.'
            self._queue.queue.clear()#用于主卡和各子卡之间进行数据共享
            self._activated = False
            self._registry.clear()
        future = FutureResult()#存储各子卡的计算结果
        self._registry[identifier] = _MasterRegistry(future)#注册子卡
        return SlavePipe(identifier, self._queue, future)#返回一个子卡用于同步的对象

主卡进行同步管理的类中注册子卡的方法

class SyncMaster(object):#主卡通过实例化该对象来进行同步管理
    def __init__(self, master_callback):
        self._master_callback = master_callback
        self._queue = queue.Queue()#主卡与子卡共享队列传递与获取计算结果
        self._registry = collections.OrderedDict()#通过它识别各子卡
        self._activated = False

主卡进行同步管理的类

_MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])
class SlavePipe(_SlavePipeBase):
    def run_slave(self, msg):
        self.queue.put((self.identifier, msg))
        ret = self.result.get()
        self.queue.put(True)
        return ret

子卡进行同步操作的类

SynBN的前向反馈

首先,每张卡上的 SyncBN 各自计算出 mini-batch 的均值和与平方和,然后主卡上的 SyncBN 收集来自各个子卡的计算结果,从而计算出全局的均值和方差,接着发放回各个子卡,最后各子卡的 SyncBN 收到来自主卡返回的计算结果各自进行归一化(和缩放平移)操作。当然,主卡上的 SyncBN 计算出全局统计量后就可以进行它的归一化(和缩放平移)操作了。

  • 1
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值