为什么LSTM可以缓解梯度消失?

参考资料:
Why LSTMs Stop Your Gradients From Vanishing: A View from the Backwards Pass

RNN


前言

  早些时候写了一篇关于RNN/LSTM的博客,介绍了RNN、LSTM的基本原理,其中提到了RNN梯度消失的问题,借机引出了LSTM。当时的文章中只写到了LSTM可以缓解梯度消失,但没有写明原因,原因是当时太想当然了,没有仔细思考这个问题。由于那篇博文的阅读量很多,本着负责的态度,现在重新把这个问题翻出来好好解释一下。

  本文首先简单回顾RNN产生梯度消失的原因,然后阐述LSTM缓解梯度消失真正的原因。

回归RNN产生梯度消失的原因

这里写图片描述
  上图为RNN的结构图,对于t时刻,其前向传播的公式为:
h ( t ) = ϕ ( U x ( t ) + W h ( t − 1 ) + b ) h^{(t)}=\phi(Ux^{(t)}+Wh^{(t-1)}+b) h(t)=ϕ(Ux(t)+Wh(t1)+b) o ( t ) = V h ( t ) + c o^{(t)}=Vh^{(t)}+c o(t)=Vh(t)+c y ^ ( t ) = σ ( o ( t ) ) \widehat{y}^{(t)}=\sigma(o^{(t)}) y (t)=σ(o(t))   其中 ϕ ( ) \phi() ϕ()为激活函数,一般来说会选择tanh函数,b为偏置; o ( t ) o^{(t)} o(t)为输出, y ^ ( t ) \widehat{y}^{(t)} y (t)为最终预测值; σ \sigma σ为网络尾部的函数,若为分类任务,一般为softmax。

  RNN的反向传播为BPTT,需要寻优的参数有三个,分别是U、V、W,三者的偏导数为:
∂ L ∂ V = ∑ t = 1 n ∂ L ( t ) ∂ o ( t ) ⋅ ∂ o ( t ) ∂ V \frac{\partial L}{\partial V}=\sum_{t=1}^{n}\frac{\partial L^{(t)}}{\partial o^{(t)}}\cdot \frac{\partial o^{(t)}}{\partial V} VL=t=1no(t)L(t)Vo(t) ∂ L ( 3 ) ∂ W = ∂ L ( 3 ) ∂ o ( 3 ) ∂ o ( 3 ) ∂ h ( 3 ) ∂ h ( 3 ) ∂ W + ∂ L ( 3 ) ∂ o ( 3 ) ∂ o ( 3 ) ∂ h ( 3 ) ∂ h ( 3 ) ∂ h ( 2 ) ∂ h ( 2 ) ∂ W + ∂ L ( 3 ) ∂ o ( 3 ) ∂ o ( 3 ) ∂ h ( 3 ) ∂ h ( 3 ) ∂ h ( 2 ) ∂ h ( 2 ) ∂ h ( 1 ) ∂ h ( 1 ) ∂ W \frac{\partial L^{(3)}}{\partial W}=\frac{\partial L^{(3)}}{\partial o^{(3)}}\frac{\partial o^{(3)}}{\partial h^{(3)}}\frac{\partial h^{(3)}}{\partial W}+\frac{\partial L^{(3)}}{\partial o^{(3)}}\frac{\partial o^{(3)}}{\partial h^{(3)}}\frac{\partial h^{(3)}}{\partial h^{(2)}}\frac{\partial h^{(2)}}{\partial W}+\frac{\partial L^{(3)}}{\partial o^{(3)}}\frac{\partial o^{(3)}}{\partial h^{(3)}}\frac{\partial h^{(3)}}{\partial h^{(2)}}\frac{\partial h^{(2)}}{\partial h^{(1)}}\frac{\partial h^{(1)}}{\partial W} WL(3)=o(3)L(3)h(3)o(3)Wh(3)+o(3)L(3)h(3)o(3)h(2)h(3)Wh(2)+o(3)L(3)h(3)o(3)h(2)h(3)h(1)h(2)Wh(1) ∂ L ( 3 ) ∂ U = ∂ L ( 3 ) ∂ o ( 3 ) ∂ o ( 3 ) ∂ h ( 3 ) ∂ h ( 3 ) ∂ U + ∂ L ( 3 ) ∂ o ( 3 ) ∂ o ( 3 ) ∂ h ( 3 ) ∂ h ( 3 ) ∂ h ( 2 ) ∂ h ( 2 ) ∂ U + ∂ L ( 3 ) ∂ o ( 3 ) ∂ o ( 3 ) ∂ h ( 3 ) ∂ h ( 3 ) ∂ h ( 2 ) ∂ h ( 2 ) ∂ h ( 1 ) ∂ h ( 1 ) ∂ U \frac{\partial L^{(3)}}{\partial U}=\frac{\partial L^{(3)}}{\partial o^{(3)}}\frac{\partial o^{(3)}}{\partial h^{(3)}}\frac{\partial h^{(3)}}{\partial U}+\frac{\partial L^{(3)}}{\partial o^{(3)}}\frac{\partial o^{(3)}}{\partial h^{(3)}}\frac{\partial h^{(3)}}{\partial h^{(2)}}\frac{\partial h^{(2)}}{\partial U}+\frac{\partial L^{(3)}}{\partial o^{(3)}}\frac{\partial o^{(3)}}{\partial h^{(3)}}\frac{\partial h^{(3)}}{\partial h^{(2)}}\frac{\partial h^{(2)}}{\partial h^{(1)}}\frac{\partial h^{(1)}}{\partial U} UL(3)=o(3)L(3)h(3)o(3)Uh(3)+o(3)L(3)h(3)o(3)h(2)h(3)Uh(2)+o(3)L(3)h(3)o(3)h(2)h(3)h(1)h(2)Uh(1)

  我们根据上面两个式子可以写出L在t时刻对W和U偏导数的通式:
∂ L ( t ) ∂ W = ∑ k = 1 t ∂ L ( t ) ∂ o ( t ) ∂ o ( t ) ∂ h ( t ) ( ∏ j = k + 1 t ∂ h ( j ) ∂ h ( j − 1 ) ) ∂ h ( k ) ∂ W \frac{\partial L^{(t)}}{\partial W}=\sum_{k=1}^{t}\frac{\partial L^{(t)}}{\partial o^{(t)}}\frac{\partial o^{(t)}}{\partial h^{(t)}}(\prod_{j=k+1}^{t}\frac{\partial h^{(j)}}{\partial h^{(j-1)}})\frac{\partial h^{(k)}}{\partial W} WL(t)=k=1to(t)L(t)h(t)o(t)(j=k+1th(j1)h(j))Wh(k) ∂ L ( t ) ∂ U = ∑ k = 1 t ∂ L ( t ) ∂ o ( t ) ∂ o ( t ) ∂ h ( t ) ( ∏ j = k + 1 t ∂ h ( j ) ∂ h ( j − 1 ) ) ∂ h ( k ) ∂ U \frac{\partial L^{(t)}}{\partial U}=\sum_{k=1}^{t}\frac{\partial L^{(t)}}{\partial o^{(t)}}\frac{\partial o^{(t)}}{\partial h^{(t)}}(\prod_{j=k+1}^{t}\frac{\partial h^{(j)}}{\partial h^{(j-1)}})\frac{\partial h^{(k)}}{\partial U} UL(t)=k=1to(t)L(t)h(t)o(t)(j=k+1th(j1)h(j))Uh(k)
  整体的偏导公式就是将其按时刻再一一加起来。

  激活函数是嵌套在里面的,如果我们把激活函数放进去,拿出中间累乘的那部分:
∏ j = k + 1 t ∂ h j ∂ h j − 1 = ∏ j = k + 1 t t a n h ′ ⋅ W s \prod_{j=k+1}^{t}{\frac{\partial{h^{j}}}{\partial{h^{j-1}}}} = \prod_{j=k+1}^{t}{tanh^{'}}\cdot W_{s} j=k+1thj1hj=j=k+1ttanhWs  或是
∏ j = k + 1 t ∂ h j ∂ h j − 1 = ∏ j = k + 1 t s i g m o i d ′ ⋅ W s \prod_{j=k+1}^{t}{\frac{\partial{h^{j}}}{\partial{h^{j-1}}}} = \prod_{j=k+1}^{t}{sigmoid^{'}}\cdot W_{s} j=k+1thj1hj=j=k+1tsigmoidWs

  我们会发现累乘会导致激活函数导数的累乘,同时还有权值的累乘。若权值小于1,几乎不可避免地会导致“梯度消失“现象;如果权值很大,可能会导致“梯度爆炸“现象。在实践过程中,还是梯度消失现象更容易发生。

LSTM缓解梯度消失的原因

  先来简单回顾LSTM,一图流。
在这里插入图片描述
  其中最关键的就是cell state的传播流程,大部分网络上的传言说因为cell state的传播是靠加法的,所以有效抑制了梯度消失,这是扯淡的。

  cell state的传播公式在远古时期(1997年版本)的LSTM是这样的:
C t = C t − 1 + i C ~ t C_{t}=C_{t-1}+i \widetilde{C}_{t} Ct=Ct1+iC t   没错,没有遗忘门!如果在这个版本说是因为加法有效抑制了梯度消失,那还多多少少有几分道理。为什么说道理只有几分,是因为很多人有一个误解:远古版本的cell state的求导导数为1,梯度可以恒定传播,很多人忽略了后面 i C ~ t i \widetilde{C}_{t} iC t。不过对于远古版本的LSTM的代码来说,cell state反向传播导数确实为1,以为梯度截断去掉了后面那部分的影响。原文截取如下:

However,to ensure non-decaying error backprop through internal states of memory cells, as with truncated BPTT (e.g.,Williams and Peng 1990), errors arriving at “memory cell net inputs” [the cell output, input, forget, and candidate gates] …do not get propagated back further in time (although they do serve to change the incoming weights).Only within memory cells [the cell state],errors are propagated back through previous internal states.

  对于远古版本的LSTM来说,即使考虑了后面那部分,导数依然不会小于1,梯度消失现象确实也就不会发生,但为什么好端端的后来就加了个遗忘门呢?

  原因是cell state不能只进不出,当序列过长的时候,cell state后面会变成庞然大物,反而影响模型的效果,所以后来加入了遗忘门。加入遗忘门这个操作,可以说是更容易让LSTM产生梯度消失了,但相比遗忘门带来的收益,这点儿损失不算什么。

  但是现在的LSTM在缓解梯度消失问题上的表现也是非常不错了,其原因还是在于BPTT的过程中,我们来看一下现版本LSTM的cell state反向传播的公式:
∂ C t ∂ C t − 1 = ∂ C t ∂ f t ∂ f t ∂ h t − 1 ∂ h t − 1 ∂ C t − 1 + ∂ C t ∂ i t ∂ i t ∂ h t − 1 ∂ h t − 1 ∂ C t − 1 + ∂ C t ∂ C ~ t ∂ C ~ t ∂ h t − 1 ∂ h t − 1 ∂ C t − 1 + ∂ C t ∂ C t − 1 \begin{aligned} \frac{\partial C_{t}}{\partial C_{t-1}} &=\frac{\partial C_{t}}{\partial f_{t}} \frac{\partial f_{t}}{\partial h_{t-1}} \frac{\partial h_{t-1}}{\partial C_{t-1}}+\frac{\partial C_{t}}{\partial i_{t}} \frac{\partial i_{t}}{\partial h_{t-1}} \frac{\partial h_{t-1}}{\partial C_{t-1}} \\ &+\frac{\partial C_{t}}{\partial \widetilde{C}_{t}} \frac{\partial \widetilde{C}_{t}}{\partial h_{t-1}} \frac{\partial h_{t-1}}{\partial C_{t-1}}+\frac{\partial C_{t}}{\partial C_{t-1}} \end{aligned} Ct1Ct=ftCtht1ftCt1ht1+itCtht1itCt1ht1+C tCtht1C tCt1ht1+Ct1Ct  这才是考虑了 i C ~ t i \widetilde{C}_{t} iC t的cell state反向求导公式,进一步推导得到:
∂ C t ∂ C t − 1 = C t − 1 σ ′ ( ⋅ ) W f ∗ o t − 1 tanh ⁡ ′ ( C t − 1 ) + C ~ t σ ′ ( ⋅ ) W i ∗ o t − 1 tanh ⁡ ′ ( C t − 1 ) + i t tanh ⁡ ′ ( ⋅ ) W C ∗ o t − 1 tanh ⁡ ′ ( C t − 1 ) + f t \begin{aligned} \frac{\partial C_{t}}{\partial C_{t-1}} &=C_{t-1} \sigma^{\prime}(\cdot) W_{f} * o_{t-1} \tanh ^{\prime}\left(C_{t-1}\right) \\ &+\widetilde{C}_{t} \sigma^{\prime}(\cdot) W_{i} * o_{t-1} \tanh ^{\prime}\left(C_{t-1}\right) \\ &+i_{t} \tanh ^{\prime}(\cdot) W_{C} * o_{t-1} \tanh ^{\prime}\left(C_{t-1}\right) \\ &+f_{t} \end{aligned} Ct1Ct=Ct1σ()Wfot1tanh(Ct1)+C tσ()Wiot1tanh(Ct1)+ittanh()WCot1tanh(Ct1)+ft
  这只是一步的推导,如果是多个时间步,就是多个类似公式的累乘。从这一步的结果中我们可以发现,其结果的取值范围并不一定局限在[0,1]中,而是有可能大于1的。

  那么什么情况下大于1?

  这个由LSTM自身的权值决定,那权值从何而来?当然是学习得到的,这便是LSTM牛逼之处,依靠学习得到权值去控制依赖的长度,这便是LSTM缓解梯度消失的真相。综上可以总结为两个事实:

   1、cell state传播函数中的“加法”结构确实起了一定作用,它使得导数有可能大于1;
   2、LSTM中逻辑门的参数可以一定程度控制不同时间步梯度消失的程度。

  最后,LSTM依然不能完全解决梯度消失这个问题,有文献表示序列长度一般到了三百多仍然会出现梯度消失现象。如果想彻底规避这个问题,还是transformer好用。

  • 49
    点赞
  • 172
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值