前言
梯度爆炸和梯度消失问题都是因为网络太深,网络权值更新不稳定造成的,本质上是因为梯度反向传播中的连乘效应。
前向传播:
z
1
=
w
1
X
+
b
1
,
a
1
=
σ
(
z
1
)
z
2
=
w
2
a
1
+
b
2
,
a
2
=
σ
(
z
2
)
.
.
.
z
n
=
w
n
a
n
−
1
+
b
n
,
a
n
=
σ
(
z
n
)
\begin{aligned} z_1&=w_1X+b_1,a_1=\sigma (z_1)\\ z_2&=w_2a_1+b_2,a_2=\sigma(z_2)\\ ...\\ z_n&=w_na_{n-1+b_n},a_n=\sigma(z_n)\\ \end{aligned}
z1z2...zn=w1X+b1,a1=σ(z1)=w2a1+b2,a2=σ(z2)=wnan−1+bn,an=σ(zn)
则反向传播:
α
l
o
s
s
α
w
1
=
α
l
o
s
s
α
a
n
α
a
n
α
z
n
α
z
n
α
a
n
−
1
α
a
n
−
1
α
z
n
−
1
α
z
n
−
1
α
a
n
−
2
α
a
n
−
2
α
z
n
−
2
.
.
.
α
a
1
α
z
1
α
z
1
α
w
1
=
α
l
o
s
s
s
α
a
n
⋅
σ
′
(
z
n
)
w
n
⋅
σ
′
(
z
n
−
1
)
w
n
−
1
⋅
.
.
.
⋅
σ
′
(
z
1
)
X
\begin{aligned} \frac{\alpha loss}{\alpha w_1} &=\frac{\alpha loss}{\alpha a_n}\frac{\alpha a_n}{\alpha z_n}\frac{\alpha z_n}{\alpha a_{n-1}}\frac{\alpha a_{n-1}}{\alpha z_{n-1}}\frac{\alpha z_{n-1}}{\alpha a_{n-2}}\frac{\alpha a_{n-2}}{\alpha z_{n-2}}...\frac{\alpha a_1}{\alpha z_1}\frac{\alpha z_1}{\alpha w_1}\\ &=\frac{\alpha losss}{\alpha a_n}·\sigma'(z_n)w_n·\sigma'(z_{n-1})w_{n-1}·...·\sigma'(z_1)X \end{aligned}
αw1αloss=αanαlossαznαanαan−1αznαzn−1αan−1αan−2αzn−1αzn−2αan−2...αz1αa1αw1αz1=αanαlosss⋅σ′(zn)wn⋅σ′(zn−1)wn−1⋅...⋅σ′(z1)X
-
梯度消失:与激活函数的导数 σ ′ ( x ) \sigma^{'}(x) σ′(x)有关。
假如 σ \sigma σ为sigmoid激活函数,而sigmoid的导数范围是[0,0.25],"链式法则"的累乘会导致梯度趋于0. -
梯度爆炸:与权重有关,即 ∣ σ ′ ( z ) w ∣ > 1 |\sigma'(z) w|>1 ∣σ′(z)w∣>1。
链式法则还与 ∣ σ ′ ( z ) w ∣ |\sigma'(z) w| ∣σ′(z)w∣有关,如果该值>1,"链式法则"累乘后会导致梯度趋于非常大的值.
梯度消失
与梯度太小有关。表现为只在后层学习,浅层不学习,浅层梯度基本无,权重改变量小,收敛慢,训练速度慢。
原因:
- 采用了不适合的激活函数,导致链式法则累乘时被0影响。
- 模型在训练的过程中,会不断调整数据分布,有可能接近激活函数饱和区,此时的导数很小,难以调整权重。
解决办法:
- 使用BN,将数据分布归一化。
- 预训练,微调。
- 使用relu等激活函数。
- 使用残差结构。
- LSTM。
- 正则化。
梯度爆炸
与链式法则中的权重有关。可能导致权重NAN。
原因:
- 若初始化权重太大,累乘后会爆炸。
- 梯度>1。
解决办法:
- 注意权重初始化。
- 梯度剪裁。
- BN。
- 预训练,微调。
RNN为何会梯度消失/爆炸?
首先看RNN计算流程,简设3个timestep:
前向传播:
S
1
=
W
x
X
1
+
W
s
S
0
+
b
1
S_1=W_xX_1+W_sS_0+b1
S1=WxX1+WsS0+b1,
O
1
=
W
o
S
1
+
b
2
O_1=W_oS_1+b2
O1=WoS1+b2。
S
2
=
W
x
X
2
+
W
s
S
1
+
b
1
S_2=W_xX_2+W_sS_1+b1
S2=WxX2+WsS1+b1,
O
2
=
W
o
S
2
+
b
2
O_2=W_oS_2+b2
O2=WoS2+b2。
S
3
=
W
x
X
3
+
W
s
S
2
+
b
1
S_3=W_xX_3+W_sS_2+b1
S3=WxX3+WsS2+b1,
O
3
=
W
o
S
3
+
b
2
O_3=W_oS_3+b2
O3=WoS3+b2。
此刻的损失函数: l o s s 3 = 1 2 ( Y 3 − O 3 ) 2 loss_3=\frac{1}{2}(Y_3-O_3)^2 loss3=21(Y3−O3)2。
反向传播:
需要对 W o W_o Wo, W s W_s Ws, W x W_x Wx求导,其中对 W s W_s Ws和 W x W_x Wx求导是同理的。
(1) δ l o s s 3 δ W o = δ l o s s 3 δ O 3 δ O 3 δ W o \frac{\delta loss_3}{\delta W_o}=\frac{\delta loss_3}{\delta O_3}\frac{\delta O_3}{\delta W_o} δWoδloss3=δO3δloss3δWoδO3
可以看出网络加深对于 W o W_o Wo无影响。
(2) δ l o s s 3 δ W s = δ l o s s 3 δ O 3 δ O 3 δ S 3 δ S 3 δ W s + δ l o s s 3 δ O 3 δ O 3 δ S 3 δ S 3 δ S 2 δ S 2 δ W s + δ l o s s 3 δ O 3 δ O 3 δ S 3 δ S 3 δ S 2 δ S 2 δ S 1 δ S 1 δ W s \frac{\delta loss_3}{\delta W_s}=\frac{\delta loss_3}{\delta O_3}\frac{\delta O_3}{\delta S_3}\frac{\delta S_3}{\delta W_s}+\frac{\delta loss_3}{\delta O_3}\frac{\delta O_3}{\delta S_3}\frac{\delta S_3}{\delta S_2}\frac{\delta S_2}{\delta W_s}+\frac{\delta loss_3}{\delta O_3}\frac{\delta O_3}{\delta S_3}\frac{\delta S_3}{\delta S_2}\frac{\delta S_2}{\delta S_1}\frac{\delta S_1}{\delta W_s} δWsδloss3=δO3δloss3δS3δO3δWsδS3+δO3δloss3δS3δO3δS2δS3δWsδS2+δO3δloss3δS3δO3δS2δS3δS1δS2δWsδS1。
可以简写为:
δ l o s s t δ W s = ∑ k = 0 t δ l o s s t δ O t δ O t δ S t ∏ j = k + 1 t ( δ S j δ S j − 1 ) δ S k δ W x \frac{\delta loss_t}{\delta W_s}=\sum_{k=0}^t\frac{\delta loss_t}{\delta O_t}\frac{\delta O_t}{\delta S_t}\prod_{j=k+1}^t(\frac{\delta S_j}{\delta S_{j-1}})\frac{\delta S_k}{\delta W_x} δWsδlosst=∑k=0tδOtδlosstδStδOt∏j=k+1t(δSj−1δSj)δWxδSk。
其中连乘的 ∏ j = k + 1 t ( δ S j δ S j − 1 ) \prod_{j=k+1}^t(\frac{\delta S_j}{\delta S_{j-1}}) ∏j=k+1t(δSj−1δSj)是导致梯度爆炸和消失的问题所在。
RNN梯度与其他网络梯度的区别
- MLP/CNN 中不同的层有不同的参数,各是各的梯度;而 RNN 中同样的权重在各个时间步共享,最终的梯度= 各个时间步的梯度的和。
- RNN 中总的梯度是不会消失的。即便梯度越传越弱,那也只是远距离的梯度消失,由于近距离的梯度不会消失,所有梯度之和便不会消失。RNN 所谓梯度消失的真正含义是,梯度被近距离梯度主导,导致模型难以学到远距离的依赖关系。
LSTM如何缓解梯度消失/爆炸?
LSTM介绍
- 遗忘门
- 可求得:
f
t
=
σ
(
W
f
⋅
[
h
t
−
1
,
x
t
]
+
b
f
)
f_t=\sigma (W_f·[h_{t-1},x_t]+b_f)
ft=σ(Wf⋅[ht−1,xt]+bf).
- 可求得:
f
t
=
σ
(
W
f
⋅
[
h
t
−
1
,
x
t
]
+
b
f
)
f_t=\sigma (W_f·[h_{t-1},x_t]+b_f)
ft=σ(Wf⋅[ht−1,xt]+bf).
- 输入门
可求得:- i t = σ ( W i ⋅ [ h t − 1 , x t ] + b i ) i_t=\sigma (W_i·[h_{t-1},x_t]+b_i) it=σ(Wi⋅[ht−1,xt]+bi).
-
C
^
t
=
t
a
n
h
(
W
C
⋅
[
h
t
−
1
,
x
t
]
+
b
C
)
\hat C_t=tanh (W_C·[h_{t-1},x_t]+b_C)
C^t=tanh(WC⋅[ht−1,xt]+bC).
-
C
t
=
f
t
⋅
C
t
−
1
+
i
t
⋅
C
^
t
C_t=f_t·C_{t-1}+i_t·\hat C_t
Ct=ft⋅Ct−1+it⋅C^t
- 输出门
可求得:-
O
t
=
σ
(
W
o
⋅
[
h
t
−
1
,
x
t
]
+
b
o
)
O_t=\sigma (W_o·[h_{t-1},x_t]+b_o)
Ot=σ(Wo⋅[ht−1,xt]+bo).
-
h
t
=
O
t
⋅
t
a
n
h
(
C
t
)
h_t=O_t·tanh(C_t)
ht=Ot⋅tanh(Ct).
-
O
t
=
σ
(
W
o
⋅
[
h
t
−
1
,
x
t
]
+
b
o
)
O_t=\sigma (W_o·[h_{t-1},x_t]+b_o)
Ot=σ(Wo⋅[ht−1,xt]+bo).
LSTM可以解决梯度消失,缓解梯度爆炸
整理可得公式:
- f t = σ ( W f ⋅ [ h t − 1 , x t ] + b f ) f_t=\sigma (W_f·[h_{t-1},x_t]+b_f) ft=σ(Wf⋅[ht−1,xt]+bf).
- i t = σ ( W i ⋅ [ h t − 1 , x t ] + b i ) i_t=\sigma (W_i·[h_{t-1},x_t]+b_i) it=σ(Wi⋅[ht−1,xt]+bi).
- C ^ t = t a n h ( W c ⋅ [ h t − 1 , x t ] + b c ) \hat C_t=tanh (W_c·[h_{t-1},x_t]+b_c) C^t=tanh(Wc⋅[ht−1,xt]+bc).
- C t = f t ⋅ C t − 1 + i t ⋅ C ^ t C_t=f_t·C_{t-1}+i_t·\hat C_t Ct=ft⋅Ct−1+it⋅C^t
- O t = σ ( W o ⋅ [ h t − 1 , x t ] + b o ) O_t=\sigma (W_o·[h_{t-1},x_t]+b_o) Ot=σ(Wo⋅[ht−1,xt]+bo).
- h t = O t ⋅ t a n h ( C t ) h_t=O_t·tanh(C_t) ht=Ot⋅tanh(Ct).
-
LSTM 中梯度的传播有很多条路径, C t − 1 → C t = f t ⋅ c t − 1 + i t ⋅ c ^ t C_{t-1} \rightarrow C_t=f_t·c_{t-1}+i_t·\hat c_t Ct−1→Ct=ft⋅ct−1+it⋅c^t这条路径上只有逐元素相乘和相加的操作,梯度流最稳定;但是其他路径(例如 C t − 1 → h t − 1 → i t → c t C_{t-1} \rightarrow h_{t-1} \rightarrow i_t \rightarrow c_t Ct−1→ht−1→it→ct)上梯度流与普通 RNN 类似,照样会发生相同的权重矩阵反复连乘。根据上式可以看出 C t C_t Ct公式与 h t h_t ht, i t i_t it, C ^ t \hat C_t C^t, C t − 1 C_{t-1} Ct−1有关,则可以得出:
δ C ( k ) δ C ( k − 1 ) = δ C ( k ) δ f ( k ) δ f ( k ) δ h ( k − 1 ) δ h ( k − 1 ) δ C ( k − 1 ) [ h t 公 式 ] + δ C ( k ) δ i ( k ) δ i ( k ) δ h ( k − 1 ) δ h ( k − 1 ) δ C ( k − 1 ) [ i t 公 式 ] + δ C ( k ) δ C ^ ( k ) δ C ^ ( k ) δ h ( k − 1 ) δ h ( k − 1 ) δ C ( k − 1 ) [ C ^ t 公 式 ] + δ C ( k ) δ C ( k − 1 ) [ C t 公 式 ] = C t − 1 ( σ ′ ⋅ W f ) ( o t ⋅ t a n h ′ ) + C ^ t ( σ ′ ⋅ W i ) ( o t ⋅ t a n h ′ ) + i t ( t a n h ′ ⋅ W c ) ( o t ⋅ t a n h ′ ) + f t \begin{aligned} \frac{\delta C^{(k)}}{\delta C^{(k-1)}} &=\frac{\delta C^{(k)}}{\delta f^{(k)}}\frac{\delta f^{(k)}}{\delta h^{(k-1)}}\frac{\delta h^{(k-1)}}{\delta C^{(k-1)}}[h_t公式]\\ &+\frac{\delta C^{(k)}}{\delta i^{(k)}}\frac{\delta i^{(k)}}{\delta h^{(k-1)}}\frac{\delta h^{(k-1)}}{\delta C^{(k-1)}}[i_t公式]\\ &+\frac{\delta C^{(k)}}{\delta \hat C^{(k)}}\frac{\delta \hat C^{(k)}}{\delta h^{(k-1)}}\frac{\delta h^{(k-1)}}{\delta C^{(k-1)}}[\hat C_t公式]\\ &+\frac{\delta C^{(k)}}{\delta C^{(k-1)}} [C_t公式]\\ &=C^{t-1}(\sigma'·W_f)(o^t·tanh')\\ &+\hat C^{t}(\sigma'·W_i)(o^t·tanh')\\ &+i^{t}(tanh'·W_c)(o^t·tanh')\\ &+f_t \end{aligned} δC(k−1)δC(k)=δf(k)δC(k)δh(k−1)δf(k)δC(k−1)δh(k−1)[ht公式]+δi(k)δC(k)δh(k−1)δi(k)δC(k−1)δh(k−1)[it公式]+δC^(k)δC(k)δh(k−1)δC^(k)δC(k−1)δh(k−1)[C^t公式]+δC(k−1)δC(k)[Ct公式]=Ct−1(σ′⋅Wf)(ot⋅tanh′)+C^t(σ′⋅Wi)(ot⋅tanh′)+it(tanh′⋅Wc)(ot⋅tanh′)+ft
因此RNN的问题 ∏ j = k t \prod_{j=k}^t ∏j=kt在LSTM中等价于 ( f k ⋅ f k + 1 ⋅ f 2 ⋅ . . . ⋅ f t ) + o t h e r (f^{k}·f^{k+1}·f^{2}·...·f^{t})+other (fk⋅fk+1⋅f2⋅...⋅ft)+other -
正常梯度 + 消失梯度 = 正常梯度,总的远距离梯度就不会消失,因此 LSTM 可以解决梯度消失。
- 可自主选择[0,1]之间,当遗忘门接近 1时(例如模型初始化时会把 forget bias 设置成较大的正数,让遗忘门饱和),这时候远距离梯度不消失;
- 当遗忘门接近 0时,但这时模型是故意阻断梯度流的(例如情感分析任务中有一条样本 “A,但是 B”,模型读到“但是”后选择把遗忘门设置成 0,遗忘掉内容 A,这是合理的)。
-
正常梯度 + 爆炸梯度 = 爆炸梯度,因此 LSTM 仍然有可能发生梯度爆炸。不过,由于 LSTM 和普通 RNN 相比多经过了很多次激活函数(导数都小于 1),因此 LSTM 发生梯度爆炸的频率要低得多。
参考
https://zhuanlan.zhihu.com/p/25631496
https://www.cnblogs.com/bonelee/p/10475453.html
https://www.zhihu.com/question/34878706/answer/665429718