I-概念
梯度消失和梯度爆炸是深度神经网络训练过程中常见的问题。
梯度消失是指在神经网络的反向传播中,由于神经元的梯度值过小或被截断,导致梯度无法有效地传递到前面的层,从而使得前面的层难以更新权重,影响网络的收敛速度和性能。梯度消失通常发生在使用Sigmoid、Tanh等激活函数的深度神经网络中,因为这些函数在输入较大或较小时,梯度会变得非常小,甚至趋近于零。
梯度爆炸是指在神经网络的反向传播中,由于神经元的梯度值过大或爆炸,导致梯度无法有效地传递到前面的层,从而使得前面的层难以更新权重,影响网络的收敛速度和性能。梯度爆炸通常发生在网络中存在较大的权重值或较深的网络结构时。
想要真正了解问题发生的本质,需要明确反向传播过程。
II-反向传播
以下图神经网络为例
这里定义激活函数为sigmoid函数,公式如下
σ
(
z
)
=
1
1
+
e
−
z
\sigma(z)=\frac{1}{1+e^{-z}}
σ(z)=1+e−z1
导数为
σ
‘
(
z
)
=
σ
(
z
)
(
1
−
σ
(
z
)
)
\sigma^`(z)=\sigma(z)(1-\sigma(z))
σ‘(z)=σ(z)(1−σ(z))
损失函数采用均方误差函数,公式如下
L
=
L
o
s
s
t
o
t
a
l
=
∑
1
2
(
t
a
r
g
e
t
i
−
y
i
)
2
L=Loss_{total}=\sum\frac{1}{2}(target_i-y_i)^2
L=Losstotal=∑21(targeti−yi)2
求损失函数对w的偏导,求解过程依靠链式法则
第一个反向传播:求损失函数对w3的偏导
∂
L
∂
w
3
=
∂
L
∂
y
1
∗
∂
y
1
∂
z
‘
∗
∂
z
‘
∂
w
3
\frac{\partial{L}}{\partial{w_3}}= \frac{\partial{L}}{\partial{y_1}}* \frac{\partial{y_1}}{\partial{z^`}}* \frac{\partial{z^`}}{\partial{w_3}}
∂w3∂L=∂y1∂L∗∂z‘∂y1∗∂w3∂z‘
∂ L ∂ y 1 = ∂ 1 2 [ ( t a r g e t 1 − y 1 ) 2 + ( t a r g e t 2 − y 2 ) 2 ] ∂ y 1 = y 1 − t a r g e t 1 = Δ 1 \frac{\partial{L}}{\partial{y_1}}= \frac{\partial{\frac{1}{2}[(target_1-y_1)^2}+(target_2-y_2)^2]}{\partial{y_1}}= y_1-target_1=\Delta_1 ∂y1∂L=∂y1∂21[(target1−y1)2+(target2−y2)2]=y1−target1=Δ1
∂ y 1 ∂ z ‘ = ∂ σ ( z ‘ ) ∂ z ‘ = σ ‘ ( z ‘ ) \frac{\partial{y_1}}{\partial{z^`}}=\frac{\partial{\sigma(z^`)}}{\partial{z^`}}=\sigma^`(z^`) ∂z‘∂y1=∂z‘∂σ(z‘)=σ‘(z‘)
∂ z ‘ ∂ w 3 = a = σ ( z ) \frac{\partial{z^`}}{\partial{w_3}}=a=\sigma(z) ∂w3∂z‘=a=σ(z)
第二个反向传播:求损失函数对w1的偏导
需要注意的是w1的影响不仅来自y1,还来自于y2
∂
L
∂
w
1
=
∂
L
∂
y
1
∗
∂
y
1
∂
z
‘
∗
∂
z
‘
∂
a
∗
∂
a
∂
z
∗
∂
z
∂
w
1
+
∂
L
∂
y
2
∗
∂
y
2
∂
z
‘
‘
∗
∂
z
‘
‘
∂
a
∗
∂
a
∂
z
∗
∂
z
∂
w
1
\frac{\partial{L}}{\partial{w_1}}= \frac{\partial{L}}{\partial{y_1}}* \frac{\partial{y_1}}{\partial{z^`}}* \frac{\partial{z^`}}{\partial{a}}* \frac{\partial{a}}{\partial{z}}* \frac{\partial{z}}{\partial{w_1}}+ \frac{\partial{L}}{\partial{y_2}}* \frac{\partial{y_2}}{\partial{z^{``}}}* \frac{\partial{z^{``}}}{\partial{a}}* \frac{\partial{a}}{\partial{z}}* \frac{\partial{z}}{\partial{w_1}}
∂w1∂L=∂y1∂L∗∂z‘∂y1∗∂a∂z‘∗∂z∂a∗∂w1∂z+∂y2∂L∗∂z‘‘∂y2∗∂a∂z‘‘∗∂z∂a∗∂w1∂z
部分式子已由第一个反向传播求解,其余式子等于
∂
z
‘
∂
a
=
w
3
\frac{\partial{z^`}}{\partial{a}}=w_3
∂a∂z‘=w3
∂ a ∂ z = σ ‘ ( z ) \frac{\partial{a}}{\partial{z}}=\sigma^`(z) ∂z∂a=σ‘(z)
∂ z ∂ w 1 = x 1 \frac{\partial{z}}{\partial{w_1}}=x_1 ∂w1∂z=x1
其他式子同理可得,那么代入右侧所有式子的结果
∂
L
∂
w
1
=
Δ
1
⋅
σ
‘
(
z
‘
)
⋅
w
3
⋅
σ
‘
(
z
)
⋅
x
1
+
Δ
2
⋅
σ
‘
(
z
‘
‘
)
⋅
w
4
⋅
σ
‘
(
z
)
⋅
x
1
\frac{\partial{L}}{\partial{w_1}}=\Delta_1\cdot\sigma^`(z^`)\cdot{w_3}\cdot\sigma^`(z)\cdot{x_1}+\Delta_2\cdot\sigma^`(z^{``})\cdot{w_4}\cdot\sigma^`(z)\cdot{x_1}
∂w1∂L=Δ1⋅σ‘(z‘)⋅w3⋅σ‘(z)⋅x1+Δ2⋅σ‘(z‘‘)⋅w4⋅σ‘(z)⋅x1
III-梯度消失和梯度爆炸
产生原因
从损失函数对w1的偏导公式结果可以看出,对梯度产生影响的是Δ、w、激活函数导数,不包括输入x是因为在多层神经网络中,中间层的某个节点的输入是上一个节点的输出而不是x。
梯度消失
-
激活函数:sigmoid
上图左侧为sigmoid函数图像,右侧为导数图像,可以看出求导后最大值只为0.25。
-
w
通常情况下初始化权重参数w时,用的是0,1正态分布,所以w的最大值为1。
因此,|σ`(z)w|≤0.25,多个小于1的数相乘后会越来越趋近于0,导致靠近输入层的权重偏导几乎为0,也就是几乎不更新,造成梯度消失。
梯度爆炸
与梯度消失不同,也就是当|σ`(z)w|≥1时,多个大于1的数相乘就会导致梯度过大,导致梯度更新幅度特别大,可能会溢出,模型无法收敛。上述例子中sigmoid不可能大于1,但权重可能过大会导致梯度爆炸。
梯度爆炸和梯度消失问题都是因为网络太深,网络权值更新不稳定造成的,本质上是因为梯度反向传播中的连乘效应。
解决方案
-
预训练加微调
此方法来自Hinton在2006年发表的一篇论文,Hinton为了解决梯度的问题,提出采取无监督逐层训练方法,其基本思想是每次训练一层隐节点,训练时将上一层隐节点的输出作为输入,而本层隐节点的输出作为下一层隐节点的输入,此过程就是逐层“预训练”(pre-training);在预训练完成后,再对整个网络进行“微调”(fine-tunning)。Hinton在训练深度信念网络(Deep Belief Networks中,使用了这个方法,在各层预训练完成后,再利用BP算法对整个网络进行训练。此思想相当于是先寻找局部最优,然后整合起来寻找全局最优,此方法有一定的好处,但是目前应用的不是很多了。
-
梯度剪切、正则化
梯度剪切这个方案主要是针对梯度爆炸提出的,其思想是设置一个梯度剪切阈值,然后更新梯度的时候,如果梯度超过这个阈值,那么就将其强制限制在这个范围之内。这可以防止梯度爆炸。
正则化通过对网络权重做正则限制过拟合,通常进行正则化的损失函数公式为
L o s s = ( y − w x ) 2 + α ∣ ∣ w ∣ ∣ 2 Loss=(y-wx)^2+α||w||^2 Loss=(y−wx)2+α∣∣w∣∣2
如果发生梯度爆炸,w的范数就会很大,通过正则化项可以部分限制梯度爆炸发生。 -
改变激活函数
如ReLU在输入大于0时导数恒为1,那么就不存在梯度消失和梯度爆炸问题了,但是ReLU存在dead问题,具体见激活函数笔记
不同的激活函数有各自的优缺点,在选择使用时需要进行考虑
-
残差结构
残差结构核心原理是将输入特征和参考特征拼接在一起进行训练,这样进行求导时就总会有一个1在,避免梯度消失
∂ l o s s ∂ x l = ∂ l o s s ∂ x L ∗ ∂ x L ∂ x l = ∂ l o s s ∂ x L ∗ ( 1 + ∂ ∂ x L ∑ i − l L F ( x i , W i ) ) \frac{\partial{loss}}{\partial{x_l}}=\frac{\partial{loss}}{\partial{x_L}}*\frac{\partial{x_L}}{\partial{x_l}}= \frac{\partial{loss}}{\partial{x_L}}*(1+\frac{\partial}{\partial{x_L}}\sum_{i-l}^LF(x_i,W_i)) ∂xl∂loss=∂xL∂loss∗∂xl∂xL=∂xL∂loss∗(1+∂xL∂i−l∑LF(xi,Wi))
式子中∂loss/∂xL表示的是损失函数到达第L层的梯度,1表示这种短路机制可以无损的传播梯度,另一项残差梯度需要经过有weights的层,残差梯度不会巧合的全部为-1,因此基本可以避免梯度消失问题。 -
Batch Normalization
Batch Normalization对每个小批量样本进行归一化操作,将输入数据规范化到均值为0、方差为1的分布,从而使得网络的参数更容易学习和调整,减少了层间的协变量偏移,使得梯度的传播更加稳定,避免参数值过大或过小,从而缓解了梯度消失和梯度爆炸的问题。