Batch Normalization原理解析
前言
本文章是自己参考一些书籍和博客整理的一些Batch Normalization相关资料,通篇是基于自己的理解进行的整理,以作为日后参考使用。参考资料在文后贴出。
Batch Normalization可以用于解决梯度消失和梯度爆炸问题,也包括原论文里提到的内部协方差转移(Internal Covariate Shift),所以本文章先整理了一些梯度消失和梯度爆炸以及内部协方差转移出现的原理,然后再进行Batch Normalization原理的解析。
1.1梯度消失和梯度爆炸
在一些论文(比如resnet那篇)和技术书籍中,Batch Normalization被提到可以用于解决梯度消失和梯度爆炸,在此参考《深入浅出Pytorch》这本书,给出梯度消失和梯度爆炸出现的原理。
其中
h
j
\mathbf{h}_{j}
hj为第
j
j
j层神经元的输入,
W
j
\mathbf{W}_j
Wj为第
j
j
j层神经元的权重,而
h
j
+
1
\mathbf{h}_{j+1}
hj+1为该层的输出,即作为下层的输入,理论上
h
j
+
1
=
W
j
h
j
\mathbf{h}_{j+1}=\mathbf{W}_j\mathbf{h}_{j}
hj+1=Wjhj,加上激活函数后
h
j
+
1
=
f
j
(
W
j
h
j
)
\mathbf{h}_{j+1} = f_j(\mathbf{W}_j\mathbf{h}_j )
hj+1=fj(Wjhj)。
根据微积分里的链式法则,
f
(
x
)
f(\mathbf{x})
f(x)对
x
\mathbf{x}
x的求导为:
∂
f
(
x
)
∂
x
=
∂
f
(
x
)
∂
y
∂
y
∂
x
\frac{\partial f(\mathbf{x})}{\partial \mathbf{x}}= \frac{\partial f(\mathbf{x})}{\partial \mathbf{y}} \frac{\partial \mathbf{y}}{\partial \mathbf{x}}
∂x∂f(x)=∂y∂f(x)∂x∂y
我们假设最后的损失函数是
L
=
f
n
(
h
n
)
L = f_n(\mathbf{h}_n)
L=fn(hn),是输出层神经元的函数,对两边求导,根据链式法则:
∂
L
∂
W
j
=
∂
L
∂
h
j
+
1
∂
h
j
+
1
∂
W
j
=
(
∂
L
∂
h
j
+
1
⊙
∂
f
j
(
W
j
h
j
)
∂
W
j
h
j
)
h
j
T
\frac{\partial L}{\partial \mathbf{W}_{j}} = \frac{\partial L}{\partial \mathbf{h}_{j+1}} \frac{\partial \mathbf{h}_{j+1}}{\partial \mathbf{W}_{j}} = \left(\frac{\partial L}{\partial \mathbf{h}_{j+1}} \odot \frac{\partial f_j(\mathbf{W}_j\mathbf{h}_j )}{\partial \mathbf{W}_j\mathbf{h}_j} \right)\mathbf{h}_j^T
∂Wj∂L=∂hj+1∂L∂Wj∂hj+1=(∂hj+1∂L⊙∂Wjhj∂fj(Wjhj))hjT
∂
L
∂
h
j
=
∂
L
∂
h
j
+
1
∂
h
j
+
1
∂
h
j
=
W
j
T
(
∂
L
∂
h
j
+
1
⊙
∂
f
j
(
W
j
h
j
)
∂
W
j
h
j
)
\frac{\partial L}{\partial \mathbf{h}_{j}} = \frac{\partial L}{\partial \mathbf{h}_{j+1}} \frac{\partial \mathbf{h}_{j+1}}{\partial \mathbf{h}_{j}} = \mathbf{W}_j^T \left(\frac{\partial L}{\partial \mathbf{h}_{j+1}} \odot \frac{\partial f_j(\mathbf{W}_j\mathbf{h}_j )}{\partial \mathbf{W}_j\mathbf{h}_j} \right)
∂hj∂L=∂hj+1∂L∂hj∂hj+1=WjT(∂hj+1∂L⊙∂Wjhj∂fj(Wjhj))
其中
∂
L
∂
h
j
\frac{\partial L}{\partial \mathbf{h}_{j}}
∂hj∂L式可以看成损失函数对数据的导数,即数据梯度;而
∂
L
∂
W
j
\frac{\partial L}{\partial \mathbf{W}_{j}}
∂Wj∂L为损失函数对权重的导数,即权重梯度。从公式二可以看出,数据梯度和权重有关,权重梯度和数据有关,而前一层的数据梯度和权重梯度都和后一层的数据梯度有关。
接下来就可以解释梯度消失和梯度爆炸了:
- 梯度消失:当构建的神经网络非常深时,不同的层学习的速度差异很大,表现为网络中靠近输出的层学习的情况很好,靠近输入的层学习的很慢。造成该问题的原因有很多,比如权重初始化不当,或者激活函数使用不当,拿激活函数举例更容易理解。
若使用Sigmoid或Tanh作为激活函数,它们的特点为梯度小于1。那意味反向传播时每次往下一级传播时激活函数对数据的导数都小于1,即 ∂ f ( h j ) ∂ h j \frac{\partial f(\mathbf{h}_{j})}{\partial \mathbf{h}_{j}} ∂hj∂f(hj)小于1。而每次往前一级传播数据梯度都会乘以 ∂ f ( h j ) ∂ h j \frac{\partial f(\mathbf{h}_{j})}{\partial \mathbf{h}_{j}} ∂hj∂f(hj),所以传播得越深,最后的数据梯度越小,对应的权重梯度也越小,就造成了梯度消失。所以在构建网络时,我们通常会用ReLU函数作为激活函数,因为它梯度为1。若权重初始化不当,比如一些权重过小,也会造成该问题。 - 梯度爆炸:如果权重初始化把一些权重取值太大,那么在反向传播时,每向前传播一级,数据梯度都会变大,对应的权重梯度也会叠加变大,所以造成靠近输入层的权重梯度过大。
综上,权重初始化和激活函数是造成梯度消失和梯度爆炸的主要原因,所以权重初始化时尽量将权重初始化值分布在1附近。
2.1内部协方差转移
内部协方差转移是在Batch Normalization这篇论文里提到的。上文说到深度神经网络涉及到很多层的叠加,每一层的参数更新会导致上层的输入数据分布发生变化,通过层层叠加,高层的输入分布变化会非常剧烈,这就使得高层需要不断去重新适应底层的参数更新。
也就是说我们输入的数据,经过网络的每一层都会进行一次非线性变换,一直到最后一层,此时的输入数据的分布已经被改变了,但ground truth是不会变的,这就造成了网络中靠后的神经元需要不断适应更新参数适应新的数据分布,并且每一层的更新都会影响下一层的变化,所以在优化器参数设置上需要非常谨慎。
3.1Batch Normalization原理
下面是关于Batch Normalization原理的分析,为了解决内部协方差转移,必须让网络的每一层的输入都满足独立同分布才行,而这就是Batch Normalization的作法。
拿卷积神经网络举例。假设我们网络的某一层有
k
k
k个神经元,它的前一层有
j
j
j个神经元,则第
j
j
j层的输出即为[B,j,H1,W1],其中B为Batch_Size,j为该层输出的通道数。第
j
j
j层的输出作为输入传递给第
k
k
k层,而第
k
k
k层有
k
k
k个神经元,相当于该层的输出通道数为
k
k
k个,即第
k
k
k层的每一个神经元的权重维度为[j,S,S],S为卷积核大小,每个神经元的权重与输入[B,j,H1,W1]进行卷积操作得到维度[B,1,H2,W2],而一共有
k
k
k个这样的神经元,所以第
k
k
k层输出的整体维度为[B,k,H2,W2]。可以看下图加深理解:
上图第
j
j
j层输出为[B,4,H1,W1],作为输入传递给第
k
k
k层,第
k
k
k层有两个神经元,每个神经元的权重维度为[4,S,S],但个神经元的权重与输入作卷积操作,得到的结果为[B,1,H2,W2],那么两个神经元的结果进行cat操作,得到整体结果[B,2,H2,W2]。
Batch Normalization就是作用在第
k
k
k层的输出上的,继续假设第
k
k
k层有
k
k
k个神经元,Batch_Size 为
m
m
m,表示
m
m
m个数据,所以第
k
k
k层输出的维度为
[
m
,
k
,
H
,
W
]
[m,k,H,W]
[m,k,H,W],相当于一共
m
m
m个数据,每个数据有
k
k
k个通道,每个通道为
[
H
,
W
]
[H,W]
[H,W]的矩阵,而Batch Normalization就是对
m
m
m个数据的每一个维度作正则化,如下图:
以上是我们使用BN时把它添加的位置,一般一个Conv层后就要接一个BN层,然后再接ReLU等激活层。下面再看一下BN的具体公式。
继续用上面提到的例子,
m
m
m个数据经过第
k
k
k层得到了维度为
[
m
,
k
,
H
,
W
]
[m,k,H,W]
[m,k,H,W]的输出,即
m
m
m个数据,每个数据有
k
k
k个通道,每个通道为
[
H
,
W
]
[H,W]
[H,W]的矩阵。对该输出进行Batch Normalization,就是把
m
m
m个数据的每一个通道提出来,进行正则化:
μ
1
=
1
m
∑
i
=
1
m
x
1
i
σ
1
2
=
1
m
∑
i
=
1
m
(
x
1
i
−
μ
1
)
2
x
^
1
←
x
1
i
−
μ
1
σ
1
2
+
ϵ
y
1
←
γ
1
x
1
^
+
β
1
≡
B
N
γ
1
,
β
1
(
x
1
)
\mu_{1} = \frac{1}{m} \sum_{i=1}^{m} x_{1i} \\ \sigma_{1}^{2} = \frac{1}{m} \sum_{i=1}^{m}\left(x_{1i}-\mu_{1}\right)^{2} \\ \hat{x}_1 \leftarrow \frac{x_{1i}-\mu_{1}}{\sqrt{\sigma_{1}^{2}+\epsilon}} \\ y_{1} \leftarrow \gamma_1 \hat{x_{1}}+\beta_1 \equiv B N_{\gamma_1, \beta_1}\left(x_{1}\right)
μ1=m1i=1∑mx1iσ12=m1i=1∑m(x1i−μ1)2x^1←σ12+ϵx1i−μ1y1←γ1x1^+β1≡BNγ1,β1(x1)
其中
x
1
x_1
x1表示整个Batch的第一个通道,
x
1
i
x_{1i}
x1i表示第
i
i
i个数据的第一个通道。该操作可以分为两步:
- Standardization:首先对 m m m个 m m m进行 Standardization,得到 zero mean unit variance的分布 x ^ 1 \hat{x}_1 x^1;
- scale and shift:然后再对 x ^ 1 \hat{x}_1 x^1进行scale and shift,缩放并平移到新的分布 y 1 y_1 y1,具有新的均值方差 γ 1 \gamma_1 γ1。
γ 1 \gamma_1 γ1和 β 1 \beta_1 β1为待学习的scale和shift参数,用于控制 y 1 y_1 y1的方差和均值。