参考资料:
Why LSTMs Stop Your Gradients From Vanishing: A View from the Backwards Pass
前言
早些时候写了一篇关于RNN/LSTM的博客,介绍了RNN、LSTM的基本原理,其中提到了RNN梯度消失的问题,借机引出了LSTM。当时的文章中只写到了LSTM可以缓解梯度消失,但没有写明原因,原因是当时太想当然了,没有仔细思考这个问题。由于那篇博文的阅读量很多,本着负责的态度,现在重新把这个问题翻出来好好解释一下。
本文首先简单回顾RNN产生梯度消失的原因,然后阐述LSTM缓解梯度消失真正的原因。
回归RNN产生梯度消失的原因
上图为RNN的结构图,对于t时刻,其前向传播的公式为:
h
(
t
)
=
ϕ
(
U
x
(
t
)
+
W
h
(
t
−
1
)
+
b
)
h^{(t)}=\phi(Ux^{(t)}+Wh^{(t-1)}+b)
h(t)=ϕ(Ux(t)+Wh(t−1)+b)
o
(
t
)
=
V
h
(
t
)
+
c
o^{(t)}=Vh^{(t)}+c
o(t)=Vh(t)+c
y
^
(
t
)
=
σ
(
o
(
t
)
)
\widehat{y}^{(t)}=\sigma(o^{(t)})
y
(t)=σ(o(t)) 其中
ϕ
(
)
\phi()
ϕ()为激活函数,一般来说会选择tanh函数,b为偏置;
o
(
t
)
o^{(t)}
o(t)为输出,
y
^
(
t
)
\widehat{y}^{(t)}
y
(t)为最终预测值;
σ
\sigma
σ为网络尾部的函数,若为分类任务,一般为softmax。
RNN的反向传播为BPTT,需要寻优的参数有三个,分别是U、V、W,三者的偏导数为:
∂
L
∂
V
=
∑
t
=
1
n
∂
L
(
t
)
∂
o
(
t
)
⋅
∂
o
(
t
)
∂
V
\frac{\partial L}{\partial V}=\sum_{t=1}^{n}\frac{\partial L^{(t)}}{\partial o^{(t)}}\cdot \frac{\partial o^{(t)}}{\partial V}
∂V∂L=t=1∑n∂o(t)∂L(t)⋅∂V∂o(t)
∂
L
(
3
)
∂
W
=
∂
L
(
3
)
∂
o
(
3
)
∂
o
(
3
)
∂
h
(
3
)
∂
h
(
3
)
∂
W
+
∂
L
(
3
)
∂
o
(
3
)
∂
o
(
3
)
∂
h
(
3
)
∂
h
(
3
)
∂
h
(
2
)
∂
h
(
2
)
∂
W
+
∂
L
(
3
)
∂
o
(
3
)
∂
o
(
3
)
∂
h
(
3
)
∂
h
(
3
)
∂
h
(
2
)
∂
h
(
2
)
∂
h
(
1
)
∂
h
(
1
)
∂
W
\frac{\partial L^{(3)}}{\partial W}=\frac{\partial L^{(3)}}{\partial o^{(3)}}\frac{\partial o^{(3)}}{\partial h^{(3)}}\frac{\partial h^{(3)}}{\partial W}+\frac{\partial L^{(3)}}{\partial o^{(3)}}\frac{\partial o^{(3)}}{\partial h^{(3)}}\frac{\partial h^{(3)}}{\partial h^{(2)}}\frac{\partial h^{(2)}}{\partial W}+\frac{\partial L^{(3)}}{\partial o^{(3)}}\frac{\partial o^{(3)}}{\partial h^{(3)}}\frac{\partial h^{(3)}}{\partial h^{(2)}}\frac{\partial h^{(2)}}{\partial h^{(1)}}\frac{\partial h^{(1)}}{\partial W}
∂W∂L(3)=∂o(3)∂L(3)∂h(3)∂o(3)∂W∂h(3)+∂o(3)∂L(3)∂h(3)∂o(3)∂h(2)∂h(3)∂W∂h(2)+∂o(3)∂L(3)∂h(3)∂o(3)∂h(2)∂h(3)∂h(1)∂h(2)∂W∂h(1)
∂
L
(
3
)
∂
U
=
∂
L
(
3
)
∂
o
(
3
)
∂
o
(
3
)
∂
h
(
3
)
∂
h
(
3
)
∂
U
+
∂
L
(
3
)
∂
o
(
3
)
∂
o
(
3
)
∂
h
(
3
)
∂
h
(
3
)
∂
h
(
2
)
∂
h
(
2
)
∂
U
+
∂
L
(
3
)
∂
o
(
3
)
∂
o
(
3
)
∂
h
(
3
)
∂
h
(
3
)
∂
h
(
2
)
∂
h
(
2
)
∂
h
(
1
)
∂
h
(
1
)
∂
U
\frac{\partial L^{(3)}}{\partial U}=\frac{\partial L^{(3)}}{\partial o^{(3)}}\frac{\partial o^{(3)}}{\partial h^{(3)}}\frac{\partial h^{(3)}}{\partial U}+\frac{\partial L^{(3)}}{\partial o^{(3)}}\frac{\partial o^{(3)}}{\partial h^{(3)}}\frac{\partial h^{(3)}}{\partial h^{(2)}}\frac{\partial h^{(2)}}{\partial U}+\frac{\partial L^{(3)}}{\partial o^{(3)}}\frac{\partial o^{(3)}}{\partial h^{(3)}}\frac{\partial h^{(3)}}{\partial h^{(2)}}\frac{\partial h^{(2)}}{\partial h^{(1)}}\frac{\partial h^{(1)}}{\partial U}
∂U∂L(3)=∂o(3)∂L(3)∂h(3)∂o(3)∂U∂h(3)+∂o(3)∂L(3)∂h(3)∂o(3)∂h(2)∂h(3)∂U∂h(2)+∂o(3)∂L(3)∂h(3)∂o(3)∂h(2)∂h(3)∂h(1)∂h(2)∂U∂h(1)
我们根据上面两个式子可以写出L在t时刻对W和U偏导数的通式:
∂
L
(
t
)
∂
W
=
∑
k
=
1
t
∂
L
(
t
)
∂
o
(
t
)
∂
o
(
t
)
∂
h
(
t
)
(
∏
j
=
k
+
1
t
∂
h
(
j
)
∂
h
(
j
−
1
)
)
∂
h
(
k
)
∂
W
\frac{\partial L^{(t)}}{\partial W}=\sum_{k=1}^{t}\frac{\partial L^{(t)}}{\partial o^{(t)}}\frac{\partial o^{(t)}}{\partial h^{(t)}}(\prod_{j=k+1}^{t}\frac{\partial h^{(j)}}{\partial h^{(j-1)}})\frac{\partial h^{(k)}}{\partial W}
∂W∂L(t)=k=1∑t∂o(t)∂L(t)∂h(t)∂o(t)(j=k+1∏t∂h(j−1)∂h(j))∂W∂h(k)
∂
L
(
t
)
∂
U
=
∑
k
=
1
t
∂
L
(
t
)
∂
o
(
t
)
∂
o
(
t
)
∂
h
(
t
)
(
∏
j
=
k
+
1
t
∂
h
(
j
)
∂
h
(
j
−
1
)
)
∂
h
(
k
)
∂
U
\frac{\partial L^{(t)}}{\partial U}=\sum_{k=1}^{t}\frac{\partial L^{(t)}}{\partial o^{(t)}}\frac{\partial o^{(t)}}{\partial h^{(t)}}(\prod_{j=k+1}^{t}\frac{\partial h^{(j)}}{\partial h^{(j-1)}})\frac{\partial h^{(k)}}{\partial U}
∂U∂L(t)=k=1∑t∂o(t)∂L(t)∂h(t)∂o(t)(j=k+1∏t∂h(j−1)∂h(j))∂U∂h(k)
整体的偏导公式就是将其按时刻再一一加起来。
激活函数是嵌套在里面的,如果我们把激活函数放进去,拿出中间累乘的那部分:
∏
j
=
k
+
1
t
∂
h
j
∂
h
j
−
1
=
∏
j
=
k
+
1
t
t
a
n
h
′
⋅
W
s
\prod_{j=k+1}^{t}{\frac{\partial{h^{j}}}{\partial{h^{j-1}}}} = \prod_{j=k+1}^{t}{tanh^{'}}\cdot W_{s}
j=k+1∏t∂hj−1∂hj=j=k+1∏ttanh′⋅Ws 或是
∏
j
=
k
+
1
t
∂
h
j
∂
h
j
−
1
=
∏
j
=
k
+
1
t
s
i
g
m
o
i
d
′
⋅
W
s
\prod_{j=k+1}^{t}{\frac{\partial{h^{j}}}{\partial{h^{j-1}}}} = \prod_{j=k+1}^{t}{sigmoid^{'}}\cdot W_{s}
j=k+1∏t∂hj−1∂hj=j=k+1∏tsigmoid′⋅Ws
我们会发现累乘会导致激活函数导数的累乘,同时还有权值的累乘。若权值小于1,几乎不可避免地会导致“梯度消失“现象;如果权值很大,可能会导致“梯度爆炸“现象。在实践过程中,还是梯度消失现象更容易发生。
LSTM缓解梯度消失的原因
先来简单回顾LSTM,一图流。
其中最关键的就是cell state的传播流程,大部分网络上的传言说因为cell state的传播是靠加法的,所以有效抑制了梯度消失,这是扯淡的。
cell state的传播公式在远古时期(1997年版本)的LSTM是这样的:
C
t
=
C
t
−
1
+
i
C
~
t
C_{t}=C_{t-1}+i \widetilde{C}_{t}
Ct=Ct−1+iC
t 没错,没有遗忘门!如果在这个版本说是因为加法有效抑制了梯度消失,那还多多少少有几分道理。为什么说道理只有几分,是因为很多人有一个误解:远古版本的cell state的求导导数为1,梯度可以恒定传播,很多人忽略了后面
i
C
~
t
i \widetilde{C}_{t}
iC
t。不过对于远古版本的LSTM的代码来说,cell state反向传播导数确实为1,以为梯度截断去掉了后面那部分的影响。原文截取如下:
However,to ensure non-decaying error backprop through internal states of memory cells, as with truncated BPTT (e.g.,Williams and Peng 1990), errors arriving at “memory cell net inputs” [the cell output, input, forget, and candidate gates] …do not get propagated back further in time (although they do serve to change the incoming weights).Only within memory cells [the cell state],errors are propagated back through previous internal states.
对于远古版本的LSTM来说,即使考虑了后面那部分,导数依然不会小于1,梯度消失现象确实也就不会发生,但为什么好端端的后来就加了个遗忘门呢?
原因是cell state不能只进不出,当序列过长的时候,cell state后面会变成庞然大物,反而影响模型的效果,所以后来加入了遗忘门。加入遗忘门这个操作,可以说是更容易让LSTM产生梯度消失了,但相比遗忘门带来的收益,这点儿损失不算什么。
但是现在的LSTM在缓解梯度消失问题上的表现也是非常不错了,其原因还是在于BPTT的过程中,我们来看一下现版本LSTM的cell state反向传播的公式:
∂
C
t
∂
C
t
−
1
=
∂
C
t
∂
f
t
∂
f
t
∂
h
t
−
1
∂
h
t
−
1
∂
C
t
−
1
+
∂
C
t
∂
i
t
∂
i
t
∂
h
t
−
1
∂
h
t
−
1
∂
C
t
−
1
+
∂
C
t
∂
C
~
t
∂
C
~
t
∂
h
t
−
1
∂
h
t
−
1
∂
C
t
−
1
+
∂
C
t
∂
C
t
−
1
\begin{aligned} \frac{\partial C_{t}}{\partial C_{t-1}} &=\frac{\partial C_{t}}{\partial f_{t}} \frac{\partial f_{t}}{\partial h_{t-1}} \frac{\partial h_{t-1}}{\partial C_{t-1}}+\frac{\partial C_{t}}{\partial i_{t}} \frac{\partial i_{t}}{\partial h_{t-1}} \frac{\partial h_{t-1}}{\partial C_{t-1}} \\ &+\frac{\partial C_{t}}{\partial \widetilde{C}_{t}} \frac{\partial \widetilde{C}_{t}}{\partial h_{t-1}} \frac{\partial h_{t-1}}{\partial C_{t-1}}+\frac{\partial C_{t}}{\partial C_{t-1}} \end{aligned}
∂Ct−1∂Ct=∂ft∂Ct∂ht−1∂ft∂Ct−1∂ht−1+∂it∂Ct∂ht−1∂it∂Ct−1∂ht−1+∂C
t∂Ct∂ht−1∂C
t∂Ct−1∂ht−1+∂Ct−1∂Ct 这才是考虑了
i
C
~
t
i \widetilde{C}_{t}
iC
t的cell state反向求导公式,进一步推导得到:
∂
C
t
∂
C
t
−
1
=
C
t
−
1
σ
′
(
⋅
)
W
f
∗
o
t
−
1
tanh
′
(
C
t
−
1
)
+
C
~
t
σ
′
(
⋅
)
W
i
∗
o
t
−
1
tanh
′
(
C
t
−
1
)
+
i
t
tanh
′
(
⋅
)
W
C
∗
o
t
−
1
tanh
′
(
C
t
−
1
)
+
f
t
\begin{aligned} \frac{\partial C_{t}}{\partial C_{t-1}} &=C_{t-1} \sigma^{\prime}(\cdot) W_{f} * o_{t-1} \tanh ^{\prime}\left(C_{t-1}\right) \\ &+\widetilde{C}_{t} \sigma^{\prime}(\cdot) W_{i} * o_{t-1} \tanh ^{\prime}\left(C_{t-1}\right) \\ &+i_{t} \tanh ^{\prime}(\cdot) W_{C} * o_{t-1} \tanh ^{\prime}\left(C_{t-1}\right) \\ &+f_{t} \end{aligned}
∂Ct−1∂Ct=Ct−1σ′(⋅)Wf∗ot−1tanh′(Ct−1)+C
tσ′(⋅)Wi∗ot−1tanh′(Ct−1)+ittanh′(⋅)WC∗ot−1tanh′(Ct−1)+ft
这只是一步的推导,如果是多个时间步,就是多个类似公式的累乘。从这一步的结果中我们可以发现,其结果的取值范围并不一定局限在[0,1]中,而是有可能大于1的。
那么什么情况下大于1?
这个由LSTM自身的权值决定,那权值从何而来?当然是学习得到的,这便是LSTM牛逼之处,依靠学习得到权值去控制依赖的长度,这便是LSTM缓解梯度消失的真相。综上可以总结为两个事实:
1、cell state传播函数中的“加法”结构确实起了一定作用,它使得导数有可能大于1;
2、LSTM中逻辑门的参数可以一定程度控制不同时间步梯度消失的程度。
最后,LSTM依然不能完全解决梯度消失这个问题,有文献表示序列长度一般到了三百多仍然会出现梯度消失现象。如果想彻底规避这个问题,还是transformer好用。