摘要:
在以往训练深度模型的时候,总是会遇到一个问题:每一层输入到下一层的数据的分布都是不一样的,这就要求在训练阶段只能使用较小的学习率和小心翼翼的设置初始化参数,这种问题称之为内部相关变量的漂移。
贡献1:提出了一种方法Batch Normalization来解决内部相关变量漂移的问题,使得我们可以使用更高的学习率和不需要太关心初始化问题。
贡献2:它的活动就像正则化一样,从某种意义上淘汰了Dropout技术。
贡献3:在当前的最好的模型上训练,得到相同的精度,但是少了14个step。
贡献4:在ImageNet验证集上获得了4.9%的top-5错误率,同时在测试集上获得了4.8%的top-5错误率,这超过了人眼识别的正确率。
一.介绍
随机梯度下降法(SGD),被证明了是一种有效地训练方法,其目的是要优化参数Θ,也就是最小化损失函数:
其中x1.....N是代表了数据。最小批次(mini-batch)是用来近似损失函数的梯度,通过计算
使用样本的最小批次有几个优点:1. 一个批次的梯度是对整个数据集梯度的评估,这就意味着一个批次大小的提升会带来训练结果的提升。2.计算一个有m个样本的批次的效率比单独计算m个样本号好,这是由于计算平台的平行计算原因。
但是SGD需要小心的设置超参数,比如学习率和初始化参数。由于参数是层层传递,一个小小的改动都会随着网络变深而放大。输入的分布不均匀造成网络要一直不断的适应新的分布,也就是网络要一直经历相关变量的漂移。考虑到一个网络的某一部分:
F2,和F1表示的是网络的某一层,Θ1和Θ2分别表示的是网络F1和F2的参数,那么x=F1(u,Θ1)是第一层的输出结果,l=F2(x,Θ2)是第二层的输出结果。那么第二层的参数更新情况可以这样表示:
其中m表示的是batch size,a表示的学习率。从上面我们可以得知数据的分布是对训练的效率有影响的,比如说训练和测试的数据分布是一样的。假设数据的分布一直是一样的,那么F2网络也不用一直调整参数来适应新的数据。
固定的输入对网络是有帮助的。考虑sigmoid函数z=g(Wu + b),g(x) = 1/(1+e(-x))随着|x|的增加,g(x)的倒数S(x)(1-S(x))会变小。这就意味着x = Wu+b期望较小的绝对值,但是因为X是受到W,b影响,在训练当中会使得x沉浸在非线性中并且使得收敛变慢(假如说训练的时候x变大了,个人认为这是一种不可避免的情况),随着网络的增加,这种结果会被影响(x被放大),于是造成了梯度消失问题。引入了ReLU(x) = max(x, 0)函数和小心初始化可以解决这个问题。然而如果我们可以保证输入的分布不变,将可以使得网络避免停滞在饱和状态,并且加速训练。
我们提出了Batch Normalization,主要方法就是固定每层网络输入的均值和变化。同时Batch Normalization有利于减少传播梯度对参数尺度的以来。这使得我们可以用更高的学习率同时避免网络发散。
在4.2部分,我们可以看到使用Batch Normalization只需要7%的以往训练步骤数,并且还提升精度。
二. 向着减少内部相关变量的漂移
从过去的文献中使得数据输入是白的(零均值,减少相关性)。减均值操作可以随着层层网络传递下去,这会使得数据的分布变得比较固定,加速训练。
所以,我们可以考虑如果在每一层都加入归一化操作,效果将会提升。然而这种操作在优化阶段是散置的,这就意味着更新参数的同时还要更新归一化的参数,这将会降低效率。例子如下:
所以可以看到,采用这种方法会使得梯度更新中的网络输出不变,这就意味着loss不会变,所以无限循环下去,不会有任何的优化。这种方法其实是忽略了没有使得规范化与Loss产生影响。
同时采用规范化方法,也就是计算:
这会要求需要计算下面公式:
这需要计算协方差矩阵,需要大量的计算资源。所以这就让我们寻求一种方法,可以使得规范化可导,使得梯度能传播下去,同时又可以不用计算整个数据集。
同时有些方法使用统计方法在一个单一的训练样本计算,或者是在一个图像网络计算不同特征图,使其规范化,但是这会丢失一定信息。因此我们希望我们的方法可以保留信息通过规范化手段。
三.Normalization via Mini-Batch Statistics
因为白化操作计算量大并且不是处处可导的,所以BN算法使用了两种简化方法。第一种简化方法是将激活的输入减去均值然后除以方差,具体公式如下:
但是注意到,如果将输入进行规范化会降低模型的表达能力。比如sigmoid函数,如果再输入sigmoid函数之前规范化,将会使得输入的数据位于sigmoid线性的部分。所以为了解决这个问题,引入两个参数来调节规范化的输入,具体公式如下:
其中,,这样就可以恢复原来的激活值。
补充一点个人理解:一开始看到这里的时候,我感到疑惑,如果把数据规范化后又把它反变回来,那这不是等于没有规范化嘛?后来想了想,其实我觉得这个规范化主要是针对后向传播的,也就是把规范化后的信息去更新紧接着的W权重,至于前向传播则是为了加强网络表达,所以把参数从变换中恢复出来(这样理解不知道对不对)。
第二个简化是为了解决这样的问题:如果每次训练都要依赖于整个数据集,那这将会是一种不切实际的想法。所以让每个mini-batch输出激活值的数学期望和方差,也就是让mini-batch近似的表达整个数据的分布,具体的算法是,考虑一个有m个值的mini-batch,可学习到的变换是,所以具体公式如下:
在后向传播过程中,梯度计算公式如下:
因此到目前为止,internal covariate shift的问题可以较好的解决了,同时能继续的保护网络的表达泛化能力。
3.1 训练和Batch Normalized网络的推导
Batch Normalization依赖于mini-batch可以有效地训练网络,但是在推导阶段不是必要的,具体的训练过程如下:
3.2 Batch-Normalized 卷积神经网络
考虑到网络中存在一种放射变换z = g(Wu+b),在非线性变换之前加入BN算法,也就是公式变成z = g(BN(Wu+b))。至于为什么不是直接加在上一层的输出u上,是因为u是上一层的非线性变换,如果直接加在u上,也就是公式变成z = g(WBN(u)+b),那么这会改变上一层的输出分布,并且不能一处covariate shift效果。相对的Wu+b有更好的对称性,非稀疏性分布,更加的高斯,有比较稳定的分布。
同时在规范化的时候,我们可以不对b规范化,也就是移除b,因为b的作用是有减去的效果,可以通过前面说过的beta参数去代替,所以公式变成z = g(BN(Wu))。
对于卷积层来说,为了使得规范化符合卷积的性质,所以对于一张图上的不同元素,不同的位置是用相同的广泛化方式。也就是说,对于相同的一张特征图有相同的规范化,假设B是一群即将被规范化的一组数据,那么特征图为pxq,mini-batch为m,则B的数量为 m x p x q。
3.3 Batch Normalization 可以拥有更高的学习率
在传统的深度学习中,太高的学习率会导致梯度爆炸或是消失。通过规范化激活值可以解决这个问题。因为他阻止了小小的变化被逐层放大,同时避免了网络陷入局部最优解。
同时批规范化使得网络对于输入的尺度变化有更大的弹性。对于一个数据的尺度放大了a那么输入:
和梯度:
都是不变的,这就说明了Batch Normalizatin能够应对尺度上的变化。
3.4 Batch Normalization 正则化了模型
当用Batch Normalization训练网络的时候,一个训练样本似乎是在mini-batch中连接其他样本,并且训练的网络不再为特定的样本生成确切的值。在我们的实验中,我们发现了这种影响对于泛化网络将会是有利的。过去使用Dropout来解决过拟合问题,如今有了Batch Normalization 的作用可以移除这个Dropout了。
四.实验
4.1 Activations over time
在MINST数据集上做测试,采用一个三层,每层一百单元的全连接网络,训练50000步,mini-batch为60,显示其精度和某一个激活的分布,结果如下:
4.2 ImageNet Classification
将Batch Normalization加入GoogLeNet,同时网络做了一些调整,具体情况如下:
4.2.1加速的BN网络
为了更好地体现算法的优越性,对网络的一些参数做了一些调整:1.加大学习率;2.移除Dropout层;3.减少L2正则化的权重(减少1/5);4.加速了学习率的衰减(6倍);5.对训练集的洗牌更加彻底;6.移除了对图像光计度的扭曲。
4.2.2 单一网络的分类
用了几种不同版本的网络测试:
Inception:原始的网络,学习率0.0015。
BN-Baseline: 在Inception中加入BN算法。
BN-x5:按照4.2.1修改后的网络,学习率增加5倍,变成0.0075。
BN-x30:像BN-x5网络,但是学习率只有0.045。
BN-x5-Sigmoid:像BN-x5网络一样,但是用的是sigmoid激活函数。
结果如下:
4.2.3 集成的分类
结果如下:
五.结论
BN算法加速了网络训练,其原因是因为移除了内部的相关变量漂移。对于每一个mini-batch都执行了Batch Normalization。同时对于每一个激活值只使用了两个变量就保护了网络的表达能力。加入BN算法的网络可以在非线性中训练,有更多的容忍力取增加学习略,并且不需要Dropout。
加入了BN算法到GoogLeNet后,并且配合一些小技巧,在图像分类比赛任务中获得了state-of-the-art的效果。
Batch Normalization的目标是获得激活值的稳定输出,并且我们在非线性单元前加入BN算法,是因为这有可能获得更稳定的输出。相对的,标准化的层输出会导致更稀疏的激活。但是我们没有观察非线性的输出是否更稀疏,BN算法也没有观察。BN算法的微分性质体现包括在可学习的尺度和漂移。
最后,BN算法在未来还有很多值得探讨的地方。