Batch Normalization
论文地址:https://arxiv.org/abs/1502.03167
Abstract
深度网络训练时,每一层的输入都是前一层的输出
o
u
t
=
W
T
∗
X
out = W^T*X
out=WT∗X
所以当进行反向传播时,W会根据梯度下降算法进行更新,而它的下一层的数据分布也会因此改变,这样会耗费很多时间学习数据的分布变化,降低学习速度。
我们将这种现象称为内部协变量转移(internal covariate shift),解决方法是对每层输入进行归一化。
批标准化使我们能够使用更高的学习率,并且不用太注意初始化。它也作为一个正则化项,在某些情况下不需要Dropout。
将BN应用到最先进的图像分类模型上,它取得了相同精度的前提下,减少了14倍的训练步骤,并以显著的差距击败了原始模型。
使用批标准化网络的组合,我们改进了在ImageNet分类上公布的最佳结果:达到了4.9% top-5的验证误差(和4.8%测试误差),超过了人类评估者的准确性。
Introduction
虽然SGD十分简单且高效,但是它需要小心的调整模型的超参数,尤其是学习率和参数初始化值。
由于每层都要受到上一层参数的影响,随着网络变深,网络参数的微小变化会被放大,需要不断去适应新的分布。
BN想把输入的均值方差规范化,使输入分布一致。
Towards Reducing Internal Covariate Shift
-
白化会加快收敛
-
规范化与某个样本的各层输入及所有样本的各层输入都有关(对某个规范化时用到了所有样本)
x ′ = N o r m ( x , X ) x'=Norm(x,X) x′=Norm(x,X)
在反向传播时,求导数需要考虑以下两项:
∂ N o r m ( x , X ) ∂ x a n d ∂ N o r m ( x , X ) ∂ X \frac{\partial Norm(x,X)}{\partial x} \ and \ \frac{\partial Norm(x,X)}{\partial X} ∂x∂Norm(x,X) and ∂X∂Norm(x,X)
这样基于整个训练集的白化是非常耗时的,因为白化需要计算 x 的协方差矩阵及白化部分,还需计算BP算法中的求导。
但是基于某个或者部分样本进行规范化又会changes the representation ability of a network
所以本文在minibatch内归一化,再用可以学习的 γ 和 β 来拟合minibatch的统计量与整个训练集统计量之间的关系。
Normalization via Mini-Batch Statistics
优化方案:
-
既然白化计算过程比较复杂,那我们就简化一点,比如我们可以尝试单独对每个特征进行normalizaiton就可以了,让每个特征都有均值为0,方差为1的分布就OK。(参考:https://zhuanlan.zhihu.com/p/34879333)
-
另一个问题,既然白化操作减弱了网络中每一层输入数据表达能力,那我就再加个线性变换操作,让这些数据再能够尽可能恢复本身的表达能力就好了。
simply normalizing each input of a layer may change what the layer can represent.normalizing the inputs of a sigmoid would constrain them to the linear regime of the nonlinearity
单一的normalization可能会改变一个层代表的含义,例如标准化sigmoid的输入会将他们约束到非线性的状态,我们要保证插入到网络的变化可以表示恒等变换。
为了解决这个问题,we make sure that the transformation inserted in the network can represent the identity transform.也就是用用可以学习的 γ 和 β 去拟合出与原先等价的变换。
y ( k ) = γ ( K ) x ′ ( k ) + β ( k ) y^{(k)}=\gamma^{(K)}x'^{(k)}+\beta^{(k)} y(k)=γ(K)x′(k)+β(k)
其中 γ ( k ) = V a r [ x ( k ) ] \gamma^{(k)}=\sqrt{Var[x^{(k)}]} γ(k)=Var[x(k)], β ( k ) = E [ x ( k ) ] \beta^{(k)}=E[x^{(k)}] β(k)=E[x(k)]
看了一个博客,对于这块写的很好,之前一直不理解bn均值后的尺度:
假设我们有 N 个样本,每个样本通道数为 C,高为 H,宽为 W。对其求均值和方差时,将在 N、H、W上操作,而保留通道 C 的维度。具体来说,就是把第1个样本的第1个通道,加上第2个样本第1个通道 … 加上第 N 个样本第1个通道,求平均,得到通道 1 的均值(注意是除以 N×H×W 而不是单纯除以 N,最后得到的是一个代表这个 batch 第1个通道平均值的数字,而不是一个 H×W 的矩阵)。求通道 1 的方差也是同理。对所有通道都施加一遍这个操作,就得到了所有通道的均值和方差。具体公式为:
如果把 输入:NCW*H 比为一摞书,这摞书总共有 N 本,每本有 C 页,每页有 H 行,每行 W 个字符。BN 求均值时,相当于把这些书按页码一一对应地加起来(例如第1本书第36页,第2本书第36页…),再除以每个页码下的字符总数:N×H×W,因此可以把 BN 看成求“平均书”的操作(注意这个“平均书”每页只有一个字),求标准差时也是同理。
参考
http://noahsnail.com/2017/09/04/2017-9-4-Batch Normalization论文翻译——中文版/