ICLR20 - 旷视研究院提出MABN:解决小批量训练BN不稳定的问题

本次先大体翻译MABN的论文,有机会我会把YOLO中的BN换成MABN在小批次上试试效果。

目录

  • 背景
  • 介绍
  • 批归一化中的统计量
  • 滑动平均批归一化
  • 实验
  • 论文地址

背景

  • 批度归一化(Batch Normalization/BN已经成为深度学习领域最常用的技术之一,但他的表现很受批次(Batch Size)限制小批次样本的批统计量(Batch Statistics)十分不稳定,导致训练的收敛速度比较慢,推理性能不好
  • 因此,很多改良的BN方法被提出,大体可以分成这两类:
  1. 通过纠正批统计量来还原BN在样本批次量充足时的性能,但是这些办法全都无法完全恢复BN的性能
  2. 使用实例级的归一化 (instance level normalization),使模型免受批统计量的影响。这一类方法可在一定程度上恢复BN在小批次上的性能,但是目前看来,实例级的归一化方法不能完全满足工业需求,因为这类方法必须在推理过程(inference)引入额外的非线性运算,大幅增加执行开销。
  • 这里也提到一点,因为原始的BN是在全部训练之后使用整个训练数据的统计量,而不是批统计量,因此BN是一个线性算子,可以再推理中融入卷积层。
  • 由下图1可知,实例级归一化的计算时间几乎是普通BN的两倍,因此在小批次训练中恢复BN的性能而不是推理的过程中引入任何一个非线性运算是一项困难但必要的任务
    图1:不同归一化方法的对比

简介

  • 在这篇论文中作者首次发现了在归一化的前向传播(Forward Propagation/FP)和反向传播(Backward Propagation/BP)中,不只有2个,实际上有4个批统计量参与进来。额外的与BP有关的两个批统计量的时候与模型的梯度有关,饰演了正则化模型的角色。
  • 作者做了一个实验发现(见下图2),由于批次小的缘故,在BP过程中梯度相关联的批统计量的方差甚至比大家都知道的批统计量(feature map的均值和方差)更大,所以作者认为与梯度相关联的批统计量的不稳定性是BN在小批次训练表现不好的关键原因。
    图2:批统计量随训练次数变化的趋势
  • 根据上述的分析,作者提出了一种全新的归一化的方法,叫做滑动平均批归一化(Moving Average Batch Normalization/MABN)。MABN无需在推理过程中引入任何非线性操作就可以完全解决小批次问题。MABN的核心思想是用滑动平均统计量moving average statistics/MAS)代替批统计量归一化输出的feature map
  • 本文用不同类型的滑动平均统计量分别代替参与FP和BP的批统计量,并进行理论分析以证明其合理性。但是在实践中发现,直接使用滑动平均统计量代替批统计量无法使模型训练收敛。
  • 作者将训练不收敛问题归因于训练不稳定引起的梯度爆炸。为避免训练爆炸,本文通过减少批统计量的数量,中心化卷积核的权重(weight centralization),并采用重归一化策略(Renormalization)来改进原始批归一化的形式。本文还从理论上证明了已修改的归一化形式比原始归一化形式更稳定

批归一化中的统计量

回顾普通BN

  • 首先假设BN的输入 X ∈ R B × p X \in R^{B×p} XRB×p,其中 B B B代表batch size, p p p代表feature map的数量,所以被归一化的feature map Y Y Y在迭代次数 t t t上被计算为:
    Y = X − μ B t σ B t ( 1 ) Y = \frac{X-\mu B_t}{\sigma B_t} (1) Y=σBtXμBt(1)
  • μ B t \mu B_t μBt σ B t 2 \sigma ^2_{B_t} σBt2是样本的均值和方差,另外参数 γ , β \gamma,\beta γ,β用于对Y进行缩放和偏移:
    Z = Y γ + β ( 3 ) Z = Y\gamma + \beta(3) Z=Yγ+β3
  • 由于这个线性变换的操作默认所有归一化都会添加,所以以下讨论将会省略。
  • 在局部梯度KaTeX parse error: Undefined control sequence: \bracevert at position 30: … L}{\partial Y}\̲b̲r̲a̲c̲e̲v̲e̲r̲t̲_{\theta _t,B_t…给定的情况下,其中 L L L代表损失, θ t \theta _t θt代表迭代 t t t次需要学习的全部参数,那么局部梯度KaTeX parse error: Undefined control sequence: \bracevert at position 30: … L}{\partial X}\̲b̲r̲a̲c̲e̲v̲e̲r̲t̲_{\theta _t,B_t…可以计算为:
    KaTeX parse error: Undefined control sequence: \bracevert at position 30: … L}{\partial X}\̲b̲r̲a̲c̲e̲v̲e̲r̲t̲_{\theta _t,B_t…
  • 其中 ⋅ \cdot 代表逐元素乘积, g B t gB_t gBt, Ψ B t \Psi B_t ΨBt的计算为:
    KaTeX parse error: Undefined control sequence: \bracevert at position 61: …artial Y_{b,:}}\̲b̲r̲a̲c̲e̲v̲e̲r̲t̲_{\theta _t,B_t…
  • 由等式(5)可以知道,在BP的过程中, g B t gB_t gBt, Ψ B t \Psi B_t ΨBt也属于参与BN BP的批统计量,但是以前的研究从没讨论过这两个统计量。

批统计量的不稳定性

  • 根据原始BN的论文,理想的归一化方法是使用基于整个训练数据集的期望和方差进行特征归一化操作:
    Y = X − E X V a r [ X ] ( 6 ) Y = \frac{X-EX}{\sqrt{Var[X]}}(6) Y=Var[X] XEX(6)
  • 但是,在随机梯度下降SGD(Stochastic Gradient Descent/SGD)的情形下使用全数据集的统计量是不切实际的。因此,原始BN使用随机梯度训练中的小批次(mini-batch)计算统计量以代替全数据统计量。这种简化使得将均值和方差纳入反传图中成为可能。
  • 批统计量像 μ B t , σ B t 2 \mu B_t,\sigma^2_{B_t} μBt,σBt2是一种蒙特卡洛估计(Monte Carlo Estimatior),它的方差与样本数量成反比,因此,当批次较小时,批统计量的方差会急剧增加。图2给出了在ImageNet训练期间,ResNet-50的一个具体的归一化层的批统计量的变化。
  • 批统计量包含总体的均值和方差随着模型更新而变化的信息,以便随着权重更新正确地正则化模型梯度,从而在个体样本变化与总体变化的平衡方面发挥重要作用。因此,精确估算总体统计量(population statistics)十分关键
  • 小批次统计量的不稳定性从两个方面降低了模型的性能:
  1. 小批次统计量的不稳定性使训练不稳定,导致收敛缓慢
  2. 小批次统计量的不稳定性会在批统计量和总体统计量之间产生巨大差异
  • 由于模型训练使用批统计量,评估模型使用总体统计量,因此批统计量和总体统计量之间的差异将导致训练和推理不一致,使得模型在评估集上的表现不好。

滑动平均批归一化

  • 根据上述讨论,我们知道恢复BN性能的关键是解决小批次统计量的不稳定性。因此作者给出了两种解决方案:
  1. 使用滑动平均统计量(MAS)估计总体统计量.
  2. 通过改进归一化形式减少统计量的数量。

滑动平均统计量代替批统计量

  • 当批次较小时,MAS似乎可以替代批统计量来估计总体统计量。本文考虑两类MAS:简单移动平均统计量(Simple Moving Average Statistics/SMAS)和指数移动平均统计量(Exponetial Moving Average Statistics/EMAS)。下述定理1表明,在一般条件下,SMAS和EMAS比批统计量更稳定:
    image
  • 定理1既证明了MAS比批统计量有更小的方差,也证明了在统计量收敛的时候(式(8))如果动量 α \alpha α比较大,EMAS优于SMAS,方差更低。FP统计量满足收敛性,所以用SMAS代替FP统计量;BP统计量并不一定满足这一条件,所以BP统计量仅用SMAS代替。理论分析之外,实验还有力地表明MAS替代小批次统计量的有效性。此外作者也说明了本质上BN就是用EMA代替FP批统计量。

通过减少统计量的数量稳定归一化

  • 为了进一步稳定小批次的训练过程,作者考虑用 E X 2 EX^2 EX2而不是 E X EX EX V a r ( X ) Var(X) Var(X)归一化特征图X,归一化等式可以修改为:
    Y = X χ B t , Z = Y ⋅ γ + B t ( 15 ) Y = \frac{X}{\chi B_t},Z=Y \cdot \gamma + B_t(15) Y=χBtXZ=Yγ+Bt15
  • 这样修改的好处也是显然的:在FP和BP期间只剩下两个批统计量,那么相比于原来的归一化形式,改进后的归一化层的不稳定性降低,本文也用了定理2给出修改理论的证明:
    image
  • 但是,因为消去了中心化feature的过程,模型的性能有所降低,我们可以通过增加中心化权重,逆补中心化feature map的损失。

实验

  • 作者在ImageNet、COCO数据集上都进行了测试,并且取得了不错的效果。

ImageNet

表1:ImageNet分类任务中ResNet-50 top-1错误率对比

  • 其中Regular代表batch_size=32,small表示batch_size为2。作者还做了ablation study,验证了MABN每个部分的作用:
    表2:红框里的结果是批次为正常大小时的结果(batch size=32);蓝框里的结果是批次特别小时的结果(batch size=2)
  • 可以看到只有MABN在小批次的情况下达到跟正常批次BN的表现。

COCO

  • 本文按照Mask R-CNN的基本设定在COCO数据集上进行实验,并比较了MABN和它的baseline在不同训练情形下的表现。
    image

论文地址

Towards Stablizing Batch Statistics in Backward Propagation of Batch Normalization

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值