SyncBN及其pyTorch实现

1 SyncBN原理

1.1 BN

BN操作可以描述成:
y = x − E [ x ] v a r [ x ] + ϵ ∗ γ + β y = \frac{x - E[x]}{\sqrt{var[x] + \epsilon}}* \gamma + \beta y=var[x]+ϵ xE[x]γ+β

使用BN的好处是:训练时在网络内部进行了归一化,为训练过程提供了正则化,防止了中间层feature map的协方差偏移,有助于抑制过拟合。使用BN,不需要特别依赖于初始化参数,可以使用较大的学习率,因此可以加速模型的训练过程。

下面内容参考自:https://zhuanlan.zhihu.com/p/40496177
参考沐神的帖子,BN起作用的原因可以解释成:
在这里插入图片描述BN的强大效果来自于反向传播时均值和方差对输入的梯度,即 d l d u \frac{d l}{d u} dudl d l d σ 2 \frac{d l}{d \sigma^2} dσ2dl.

1.2 数据并行

深度学习进行模型训练时,在多卡平台上往往使用数据并行,即将模型复制到各显卡上,数据等分成N份,在各卡上单独进行训练。那么如果模型中应用了BN,也是在各数据子集上计算均值和方差,然后进行规范化操作,如下图所示:
在这里插入图片描述我们知道,在训练过程中,使用更大的batchsize有助于使得训练过程比较稳定的收敛,也有助于取得更好的训练效果。但是上面这种标准的BN使用各数据子集进行规范化的,batchsize值较小。那么也就出现了SyncBN提升BN计算时的有效batchsize的大小

1.3 SyncBN

SyncBN即在数据并行分布式训练时,将各显卡上分配的数据汇总到一起进行规范化处理,也就是计算所有数据全局的均值和方差。按照沐神的说法,如果分别计算均值和方差,就需要做两次数据同步。可以分布求全局的 ∑ x i \sum {x_i} xi ∑ x i 2 \sum {x_i}^2 xi2,利用公式 v a r [ x ] = E [ x 2 ] − E 2 [ x ] var[x] = E[x^2] - E^2[x] var[x]=E[x2]E2[x]得到方差。这样就只需要做一次数据同步。
在这里插入图片描述

2 pyTorch实现

pyTorch中SyncBN的实现见:https://pytorch.org/docs/stable/generated/torch.nn.SyncBatchNorm.html?highlight=torch%20nn%20syncbatchnorm#torch.nn.SyncBatchNorm

torch.nn.SyncBatchNorm(num_features: int, eps: float = 1e-05, momentum: float = 0.1, affine: bool = True, track_running_stats: bool = True, process_group: Optional[Any] = None)

计算的是同一个进程组中所有mini-batch数据在各通道上的均值和方差。

γ \gamma γ β \beta β是大小为C(C也是输入数据的数量)的可学习参数向量。 γ \gamma γ的元素采样自 u ( 0 , 1 ) u(0,1) u(0,1) β \beta β初始化为0.

训练过程中,均值和方差是求得移动平均值,具体公式为 x ^ n e w = ( 1 − m o m e n t u m ) × x ^ + m o m e n t u m × x t \hat{x}_{new} = (1 - momentum) \times \hat x + momentum \times x_t x^new=(1momentum)×x^+momentum×xt x ^ \hat x x^是历史统计值, x t x_t xt是最新观测值。momentum默认值是0.1。最终求得的统计平均值用于模型推理过程中。但是如果设置track_running_stats = False,就不会在模型训练过程中求均值和方差的移动平均了,训练和推理的时候用的都是一个batch数据的均值和方差。

当前SyncBatchNorm仅支持在DDP模式下使用,且要求每个显卡部署一个进程。可以使用下面介绍的torch.nn.SyncBatchNorm.convert_sync_batchnorm()函数在DDP模式下进行网络训练前将BatchNorm*D调整为SyncBatchNorm

参数:



        num_features – C from an expected input of size (N,C,+)

        eps – a value added to the denominator for numerical stability. Default: 1e-5

        momentum – the value used for the running_mean and running_var computation. Can be set to None for cumulative moving average (i.e. simple average). Default: 0.1

        affine – a boolean value that when set to True, this module has learnable affine parameters. Default: True

        track_running_stats – a boolean value that when set to True, this module tracks the running mean and variance, and when set to False, this module does not track such statistics, and initializes statistics buffers running_mean and running_var as None. When these buffers are None, this module always uses batch statistics. in both training and eval modes. Default: True

        process_group – synchronization of stats happen within each process group individually. Default behavior is synchronization across the whole world(默认是所有进程进行SyncBN)

转换函数:

classmethod convert_sync_batchnorm(module, process_group=None)
Helper function to convert all BatchNorm*D layers in the model to torch.nn.SyncBatchNorm layers.


Parameters:

        module (nn.Module) – module containing one or more attr:BatchNorm*D layers

        process_group (optional)process group to scope synchronization, default is the whole world


Returns:
		The original module with the converted torch.nn.SyncBatchNorm layers. If the original module is a BatchNorm*D layer, a new torch.nn.SyncBatchNorm layer object will be returned instead.

示例代码:

>>> # With Learnable Parameters
>>> m = nn.SyncBatchNorm(100)
>>> # creating process group (optional)
>>> # process_ids is a list of int identifying rank ids.
>>> process_group = torch.distributed.new_group(process_ids)
>>> # Without Learnable Parameters
>>> m = nn.BatchNorm3d(100, affine=False, process_group=process_group)
>>> input = torch.randn(20, 100, 35, 45, 10)
>>> output = m(input)

>>> # network is nn.BatchNorm layer
>>> sync_bn_network = nn.SyncBatchNorm.convert_sync_batchnorm(network, process_group)
>>> # only single gpu per process is currently supported
>>> ddp_sync_bn_network = torch.nn.parallel.DistributedDataParallel(
>>>                         sync_bn_network,
>>>                         device_ids=[args.local_rank],
>>>                         output_device=args.local_rank)
  • 18
    点赞
  • 42
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值