LSTM解决RNN梯度消失、爆炸问题~~

转自知乎作者:沉默中的思索

原文链接:https://www.zhihu.com/people/chen-mo-zhong-de-si-suo/activities

先上一张LSTM的经典图:

 

至于这张图的详细介绍请参考:Understanding LSTM Networks

下面假设你已经阅读过Understanding LSTM Networks这篇文章了,并且了解了LSTM的组成结构。

RNN梯度消失和爆炸的原因这篇文章中提到的RNN结构可以抽象成下面这幅图:

 

而LSTM可以抽象成这样:

 

三个×分别代表的就是forget gate,input gate,output gate,而我认为LSTM最关键的就是forget gate这个部件。这三个gate是如何控制流入流出的呢,其实就是通过下面 f_{t},i_{t},o_{t} 三个函数来控制,因为 \sigma(x)(代表sigmoid函数) 的值是介于0到1之间的,刚好用趋近于0时表示流入不能通过gate,趋近于1时表示流入可以通过gate。

f_{t}=\sigma({W_{f}X_{t}}+b_{f})

i_{t}=\sigma({W_{i}X_{t}}+b_{i})

o_{i}=\sigma({W_{o}X_{t}}+b_{o})

当前的状态 S_{t}=f_{t}S_{t-1}+i_{t}X_{t}类似与传统RNN S_{t}=W_{s}S_{t-1}+W_{x}X_{t}+b_{1}。将LSTM的状态表达式展开后得:

S_{t}=\sigma(W_{f}X_{t}+b_{f})S_{t-1}+\sigma(W_{i}X_{t}+b_{i})X_{t}

如果加上激活函数, S_{t}=tanh\left[\sigma(W_{f}X_{t}+b_{f})S_{t-1}+\sigma(W_{i}X_{t}+b_{i})X_{t}\right]

RNN梯度消失和爆炸的原因这篇文章中传统RNN求偏导的过程包含 \prod_{j=k+1}^{t}\frac{\partial{S_{j}}}{\partial{S_{j-1}}}=\prod_{j=k+1}^{t}{tanh{'}W_{s}}

对于LSTM同样也包含这样的一项,但是在LSTM中 \prod_{j=k+1}^{t}\frac{\partial{S_{j}}}{\partial{S_{j-1}}}=\prod_{j=k+1}^{t}{tanh{’}\sigma({W_{f}X_{t}+b_{f}})}

假设 Z=tanh{'}(x)\sigma({y}) ,则 Z 的函数图像如下图所示:

 

可以看到该函数值基本上不是0就是1。

再看看RNN梯度消失和爆炸的原因这篇文章中传统RNN的求偏导过程:

\frac{\partial{L_{3}}}{\partial{W_{s}}}=\frac{\partial{L_{3}}}{\partial{O_{3}}}\frac{\partial{O_{3}}}{\partial{S_{3}}}\frac{\partial{S_{3}}}{\partial{W_{s}}}+\frac{\partial{L_{3}}}{\partial{O_{3}}}\frac{\partial{O_{3}}}{\partial{S_{3}}}\frac{\partial{S_{3}}}{\partial{S_{2}}}\frac{\partial{S_{2}}}{\partial{W_{s}}}+\frac{\partial{L_{3}}}{\partial{O_{3}}}\frac{\partial{O_{3}}}{\partial{S_{3}}}\frac{\partial{S_{3}}}{\partial{S_{2}}}\frac{\partial{S_{2}}}{\partial{S_{1}}}\frac{\partial{S_{1}}}{\partial{W_{s}}}

如果在LSTM中上式可能就会变成:

\frac{\partial{L_{3}}}{\partial{W_{s}}}=\frac{\partial{L_{3}}}{\partial{O_{3}}}\frac{\partial{O_{3}}}{\partial{S_{3}}}\frac{\partial{S_{3}}}{\partial{W_{s}}}+\frac{\partial{L_{3}}}{\partial{O_{3}}}\frac{\partial{O_{3}}}{\partial{S_{3}}}\frac{\partial{S_{2}}}{\partial{W_{s}}}+\frac{\partial{L_{3}}}{\partial{O_{3}}}\frac{\partial{O_{3}}}{\partial{S_{3}}}\frac{\partial{S_{1}}}{\partial{W_{s}}}

因为 \prod_{j=k+1}^{t}\frac{\partial{S_{j}}}{\partial{S_{j-1}}}=\prod_{j=k+1}^{t}{tanh{’}\sigma({W_{f}X_{t}+b_{f}})}\approx0|1 ,这样就解决了传统RNN中梯度消失的问题。

  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值