Batch normalization对数据进行标准化,标准化之后数据的均值为0,方差为1。Batch normalization是解决神经网络中的内部单元的internal covariate shift问题。
Internal covariate shift,由于网络参数变化,输入变化等原因,造成网络内部状态发生偏移,导致激活函数的输入在饱和区域,或者激活函数关于其输入的导数接近于0。比如,sigmoid函数,见下图,在两端的函数值接近于0和1的区域,导数接近0。在BP过程中,由于sigmoid函数的导数近似为0,那么随着误差信号的传播,参数的梯度趋于0,即梯度消失。梯度消失问题严重阻碍了网络的训练,大大提高了训练的时间成本,有时甚至造成网络不可训练。
Batch normalization正是一个用来解决造成梯度消失的方法。Batch normalization的思想:既然激活函数的输入落在了导数近似0的区域,那么可将其强制拉回到激活函数的近似线性区域。方程原理如下。
上述前3个方程将输入x标准化为0均值1方差,但是标准化强行改变了输入的值,即改变了网络提取的特征,降低了网络的表达能力。因此,上述第4个方程通过线性映射,对标准化后的值进行补偿,以保留网络的表达能力。
综上所述,batch normalization通过改变数据的均值和方差来达到解决internal covariate shift的目的。标准化将激活函数的输入映射到了近似线性区域,导数不再趋于0,促进误差信号的反向传播,加快了网络的训练。此外,batch normalization从内部状态分布和网络结构两个视角将网络推向更深层。
参考文献:
Sergey Ioffe, Christian Szegedy. Batch Normalization: ccelerating Deep Network Training by Reducing Internal Covariate Shift