Batch normalization
要知道,数据分布太多元化并不利于DNN的训练。
BN的提出就是为了狙击这个问题。
出于whitening的代价太大了,BN采用了mini-batch进行normalization的方案。
用一个 N × C × H × W N\times C \times H \times W N×C×H×W的数据举栗:
在大小为 M ( M < N ) M(M<N) M(M<N)的batch之中,基于信道 C C C求均值、方差。
μ B ← 1 m ∑ i = 1 m x i \mu_B \leftarrow \frac{1}{m}\sum_{i=1}^mx_i μB←m1∑i=1mxi
σ B 2 ← 1 m ∑ i = 1 m ( x i − μ B ) 2 \sigma_B^2 \leftarrow \frac{1}{m}\sum_{i=1}^m(x_i-\mu_B)^2 σB2←m1∑i=1m(xi−μB)2
x ^ i ← x i − μ B σ B 2 + ϵ \hat x_i \leftarrow \frac{x_i-\mu_B}{\sqrt{\sigma_B^2+\epsilon}} x^i←σB2+ϵxi−μB
y i ← γ x ^ i + β y_i \leftarrow \gamma\hat x_i +\beta yi←γx^i+β
x ^ i \hat x_i x^i就是减去均值,除去方差根的数据。
然后通过 γ , β \gamma,\beta γ,β 再学习得到新的数据 y i y_i yi。
(没有 γ , β \gamma,\beta γ,β的话,所有数据都是一个分布,这样每层传递出的信息就被切断了。 γ , β \gamma,\beta γ,β的加入就是为了在「弥补信息损失」和「加速收敛」之间找到一个平衡点)
不过这是在training阶段,在inference阶段,是用training时候得到的均值和方差(经过小小处理)来对test-data进行归一化的(单纯线性变换一下了,很快):
x ^ = x − E [ x ] V a r [ x ] + ϵ \hat x = \frac{x-E[x]}{\sqrt{Var[x]+\epsilon}} x^=Var[x]+ϵx−E[x]
V a r [ x ] = m m − 1 E B [ σ B 2 ] , E [ x ] = E [ μ B ] Var[x]=\frac{m}{m-1}E_B[ \sigma_B^2],E[x]=E[\mu_B] Var[x]=m−1mEB[σB2],E[x]=E[μB]
⬆️ V a r [ x ] , E [ x ] Var[x],E[x] Var[x],E[x]是整合了training期间各个 μ B , σ B 2 \mu_B,\sigma_B^2 μB,σB2得到的Ծ ̮ Ծ。
在BN的论文中,用了 s i g m o i d sigmoid sigmoid的激活函数,此举将数据大致固定在 s i g m o i d sigmoid sigmoid的线形区域,起到了加速收敛,正则化的作用。(why?)
个人理解:
在 s i g m o i d sigmoid sigmoid的饱和区,bp的时候, s i g m o i d ′ ( ) sigmoid'() sigmoid′()是非常小的(今天定个小目标,先给更新一丢丢),在线形区收敛速度自然会快。
至于正则化,也是因为线性区的缘故:线性变换+线性变换+线性变换=线性变换(✧(σ๑˃̶̀ꇴ˂̶́)σ ),这样子出来的曲线就会不会东扭西扭,扭成麻花导致过拟合啦。
DNN在反向传播过程中有可能会出现梯度弥散,BN对这个则有奇效。
首先梯度弥散是如何造成的呢?
一个原因是激活函数可能陷入饱和区导致grad太小啦(上面说到的 s i g m o i d ′ ( ) sigmoid'() sigmoid′()就是栗子)。
还有一个就是 W 1 W 2 W 3 … W1W2W3… W1W2W3…,如果 W 1 < 1 , W 2 < 1 , W 3 < 1 … W1<1,W2<1,W3<1… W1<1,W2<1,W3<1…,相乘直接就vanish咯。
BN应对第一个原因就是弯的给他掰直!(斜眼笑)
至于应对第二个原因, ∂ y ∂ x ^ = γ \frac{\partial y}{\partial\hat x}=\gamma ∂x^∂y=γ ,不过我认为 γ \gamma γ没有这么猛吧(小声BB,待我跑个程序验证一下)。
应该是如Physcal所说的:
直观上来说,对于一个网络层:
I、 X X X大点, W W W肯定要小点。
II、 X X X小点, W W W肯定要大点。
违反这两条,会让激活值处于函数边界,从而被自然选择淘汰掉(有点遗传算法的味道)。
这是经典的大拇指规则(Rule of Thumb),由无数前辈的实验得到,似乎已经成了共识。
normalize之后,各层的XX遭到了压制,并且向高斯分布中心进行数值收缩。
进而,由 X X X影响到了 W W W, W W W也向高斯分布中心进行数值收缩。
这样, W 1 W 2 W 3 … … . W n W1W2W3…….Wn W1W2W3…….Wn的衰减将会得到可观的减缓。
这大概是Batch Normalization可以减轻使用ReLU的Gradient Vanish的直接原因。
这一点也可以拿来解释Dropout的正则化作用,Dropout会暂时屏蔽一些单元,然后对保留单元做一个放大(比如说,屏蔽了1/3的单元,就将保留单元放大3/2), X X X的变大会引起 W ( a l i v e ) W(alive) W(alive)的减小,从而压制了 W W W。
不过梯度衰减这点,inception V3那篇论文说可能人们把它misplace了。。Ծ‸Ծ
Learning rate
Learning rate 是可以更大了,但也不是超级大,不然还是会摆过头的。
learning rate可以更大的原因个人认为是:因为BN后数据处于一个更为健康的状态(身体好,孩子吃饭又快又香),可以用更大的学习率,更少的梯度更新次数。这点知乎网友纳米酱说的比较nice。
Reference:
Ioffe S, Szegedy C. Batch normalization: Accelerating deep network training by reducing internal covariate shift[J]. arXiv preprint arXiv:1502.03167, 2015.