之前看过,现在突然想不起,真的是好记性不如烂笔头,希望大家在看的时候能够拿笔和纸跟着推导一遍,加深理解。
转自:https://zhuanlan.zhihu.com/p/28687529 https://zhuanlan.zhihu.com/p/28749444
1.RNN梯度弥散和爆炸的原因
经典的RNN结构如下图所示:
假设我们的时间序列只有三段,
S
0
S_0
S0为给定值,神经元没有激活函数,则RNN最简单的前向传播过程如下:
假设在t=3时刻,损失函数为
L
3
=
1
2
(
Y
3
−
O
3
)
2
L_3=\frac{1}{2}(Y_3-O_3)^2
L3=21(Y3−O3)2
则对于一次训练任务的损失函数为 L = ∑ t = 1 T L t L=\sum_{t=1}^{T}L_t L=∑t=1TLt,即每一时刻损失值的累加。
使用随机梯度下降法训练RNN其实就是对 W x W_x Wx 、 W s W_s Ws 、 W 0 W_0 W0 以及 b 1 b_1 b1 b 2 b_2 b2 求偏导,并不断调整它们以使L尽可能达到最小的过程。
现在假设我们我们的时间序列只有三段,t1,t2,t3。
我们只对t3时刻的 [公式] 求偏导(其他时刻类似):
可以看出对于
W
0
W_0
W0 求偏导并没有长期依赖,但是对于
W
x
W_x
Wx、
W
s
W_s
Ws求偏导,会随着时间序列产生长期依赖。因为
S
t
S_t
St随着时间序列向前传播,而
S
t
S_t
St又是
W
x
W_x
Wx、
W
s
W_s
Ws的函数。
根据上述求偏导的过程,我们可以得出任意时刻对
W
x
W_x
Wx、
W
s
W_s
Ws求偏导的公式:
任意时刻对 W s W_s Ws 求偏导的公式同上。
如果加上激活函数,
S
j
=
t
a
n
h
(
W
x
X
j
+
W
s
S
j
−
1
+
b
1
)
S_j=tanh(W_xX_j+W_sS_{j-1}+b_1)
Sj=tanh(WxXj+WsSj−1+b1) ,
则
∏
j
=
k
+
1
t
∂
S
j
∂
S
j
−
1
=
∏
j
=
k
+
1
t
t
a
n
h
′
W
s
\prod_{j=k+1}^{t}\frac{\partial{S_j}}{\partial{S_{j-1}}}=\prod_{j=k+1}^{t}tanh^{'}W_s
∏j=k+1t∂Sj−1∂Sj=∏j=k+1ttanh′Ws
激活函数tanh和它的导数图像如下。
由上图可以看出 t a n h ′ ≤ 1 tanh^{'} \leq1 tanh′≤1,对于训练过程大部分情况下tanh的导数是小于1的,因为很少情况下会出现 W x X j + W s S j − 1 + b 1 = 0 W_xX_j+W_sS_{j-1}+b_1=0 WxXj+WsSj−1+b1=0,如果 W s W_s Ws 也是一个大于0小于1的值,则当t很大时 ∏ j = k + 1 t t a n h ′ W s \prod_{j=k+1}^{t}tanh^{'}W_s ∏j=k+1ttanh′Ws,就会趋近于0,和 0.01^{50} 趋近与0是一个道理。同理当 W s W_s Ws很大时 ∏ j = k + 1 t t a n h ′ W s \prod_{j=k+1}^{t}tanh^{'}W_s ∏j=k+1ttanh′Ws 就会趋近于无穷,这就是RNN中梯度消失和爆炸的原因。
至于怎么避免这种现象,让我在看看 ∂ L t ∂ W x = ∑ k = 0 t ∂ L t ∂ O t ∂ O t ∂ S t ( ∏ j = k + 1 t ∂ S j ∂ S j − 1 ) ∂ S k ∂ W x \frac{\partial{L_t}}{\partial{W_x}}=\sum_{k=0}^{t}\frac{\partial{L_t}}{\partial{O_t}}\frac{\partial{O_t}}{\partial{S_t}}(\prod_{j=k+1}^{t}\frac{\partial{S_j}}{\partial{S_{j-1}}})\frac{\partial{S_k}}{\partial{W_x}} ∂Wx∂Lt=∑k=0t∂Ot∂Lt∂St∂Ot(∏j=k+1t∂Sj−1∂Sj)∂Wx∂Sk梯度消失和爆炸的根本原因就是 ∏ j = k + 1 t ∂ S j ∂ S j − 1 \prod_{j=k+1}^{t}\frac{\partial{S_j}}{\partial{S_{j-1}}} ∏j=k+1t∂Sj−1∂Sj这一坨,要消除这种情况就需要把这一坨在求偏导的过程中去掉,至于怎么去掉,一种办法就是使 ∏ j = k + 1 t ∂ S j ∂ S j − 1 ≈ 1 \prod_{j=k+1}^{t}\frac{\partial{S_j}}{\partial{S_{j-1}}}\approx1 ∏j=k+1t∂Sj−1∂Sj≈1另一种办法就是使 ∏ j = k + 1 t ∂ S j ∂ S j − 1 ≈ 0 \prod_{j=k+1}^{t}\frac{\partial{S_j}}{\partial{S_{j-1}}}\approx0 ∏j=k+1t∂Sj−1∂Sj≈0 。其实这就是LSTM做的事情,至于细节问题下节将进行介绍。
2.LSTM如何解决梯度消失问题
先上一张LSTM的经典图:
而LSTM可以抽象成这样:
三个×分别代表的就是forget gate,input gate,output gate,而我认为LSTM最关键的就是forget gate这个部件。这三个gate是如何控制流入流出的呢,其实就是通过下面
f
t
,
i
t
,
o
t
f_t,i_t,o_t
ft,it,ot 三个函数来控制,因为
σ
(
x
)
\sigma(x)
σ(x)(代表sigmoid函数) 的值是介于0到1之间的,刚好用趋近于0时表示流入不能通过gate,趋近于1时表示流入可以通过gate。
当前的状态
S
t
=
f
t
S
t
−
1
+
i
t
X
t
S_t=f_tS_{t-1}+i_tX_t
St=ftSt−1+itXt类似与传统RNN
S
t
=
W
s
S
t
−
1
+
W
x
X
t
+
b
1
S_t=W_sS_{t-1}+W_xX_t+b_1
St=WsSt−1+WxXt+b1。将LSTM的状态表达式展开后得:
如果加上激活函数,
S
t
=
t
a
n
h
[
σ
(
W
f
X
t
+
b
f
)
S
t
−
1
+
σ
(
W
i
X
t
+
b
i
)
X
t
]
S_t=tanh[\sigma(W_fX_t+b_f)S_{t-1}+\sigma(W_iX_t+b_i)X_t]
St=tanh[σ(WfXt+bf)St−1+σ(WiXt+bi)Xt]
这篇文章中传统RNN求偏导的过程包含
对于LSTM同样也包含这样的一项,但是在LSTM中
假设
Z
=
t
a
n
h
′
(
x
)
σ
(
y
)
Z=tanh'(x)\sigma(y)
Z=tanh′(x)σ(y) ,则
Z
Z
Z的函数图像如下图所示:
可以看到该函数值基本上不是0就是1。
这篇文章中传统RNN的求偏导过程:
如果在LSTM中上式可能就会变成:
因为 ∏ j = k + 1 t ∂ S j ∂ S j − 1 = ∏ j = k + 1 t t a n h ′ σ ( W f X t + b f ) ≈ 0 ∣ 1 \prod_{j=k+1}^{t}\frac{\partial{S_j}}{\partial{S_{j-1}}}=\prod_{j=k+1}^{t}tanh'\sigma(W_fX_t+b_f)\approx0|1 ∏j=k+1t∂Sj−1∂Sj=∏j=k+1ttanh′σ(WfXt+bf)≈0∣1 ,这样就解决了传统RNN中梯度消失的问题。