什么是Batch Normalization?
批量标准化(BN)指的是对神经网络每一层的输入进行标准化,目的是为了解决训练过程中输入数据分布漂移——Internal Covariate Shift。
什么是Internal Covariate Shift?
首先机器学习算法都有一个前提假设:数据是独立同分布的。简单来说就是输入空间内的所有变量都服从某一个隐含分布,而模型则是去学习这个分布。
在神经网络的训练过程中,每一层的参数变化都会导致输出与输入的分布发生变化,层层递进,深层神经网络的分布可能会发生剧烈变化。
而这就导致网络训练过程中,模型需要不断调整参数去适应这种变化,极大影响模型收敛速度与性能。
BN算法流程
1.对当前batch数据进行标准化后,再进行线性映射,训练scale与shift参数。因为对batch数据进行标准化也是改变了数据分布,为了消除这一影响,通过训练scale,shift参数来使得网络学到原本的分布。
BN算法流程图
为什么BN算法会work?
直观的理解:
1.BN将激活函数的输出从任意的正态分布拉到均值为,方差为1的标准正态分布,使得输入落到激活函数的敏感区,即较小的变化也会导致loss较大的变化,梯度变大,防止梯度消失的同时也加速模型收敛。
比如如果激活函数是sigmoid函数,在网络的训练过程中,分布会不断靠近激活函数的上下限,即导数不断靠近0.25处,层层传递,最后导致梯度消失,而BN则会将分布拉离上下限。
BN优点总结:
1.防止梯度消失
2.加速模型收敛
3.降低初始化要求,可以设置较大的初始学习率,加快学习。
4.某些情况下,可以提升模型泛化性能,因为BN也可以视为一种正则化的方法。