LSTM(Long Short-Term Memory)

LSTM(Long Short-Term Memory)

前面的两篇博客介绍了基本的循环神经网络RNN(recurrent neural network):

但是基本的RNN(之所以强调是基本的RNN,是因为LSTM本质上也是一种RNN,下面在说RNN就代指基本的RNN)也存在一些缺点,举个例子,比如我们有个句子生成的任务,有下面两句话(摘自ng deep learning课《sequence model》):

1、The cat, which already ate …, was full.
2、The cats, which already ate …, were full.

假如我们到动词部分,我们要根据前面的cat是单数还是复数来生成对应的was还是were,但是遗憾的是,基本的RNN是无法捕获这种长期依赖的,原因就在一旦需要长期依赖,RNN就会产生梯度消失。关于RNN为什么会梯度消失,这一点和深度网络中梯度消失原因是一样的,在博客RNN(recurrent neural network)(一)——基础知识里也解释过了。知乎上有篇文章从数学公式上解释了RNN梯度消失和爆炸的原理,可参考文章:RNN梯度消失和爆炸的原因。总结来看基本的RNN的缺点是:

无法处理长期依赖的问题(原因在于 梯度消失)

因此,LSTM被提出以解决这个问题。有一点需要注意的是:LSTM只能极大的缓解RNN的梯度消失,但不能从根本上解决,所幸的是大多数的任务场景下实验表明,LSTM都能够取得很好的结果。那么LSTM究竟是如何缓解梯度消失的呢?这里我们先把这个问题留在这,等我们介绍完LSTM的原理,再来回头解答这个问题,这样会比较清楚。

一、LSTM的结构

先来看看单个LSTM单元的结构图,以及处理序列展开后的LSTM结构图,这样大家会有一个直观的认识。

其中 Γ f \Gamma _{f} Γf表示遗忘门, Γ i \Gamma _{i} Γi表示输入门, Γ o \Gamma _{o} Γo表示输出门。其中,个人认为最重要的就是遗忘门。那么这几个门分别起到了什么作用?我们来看下:

  • forget gate(遗忘门):举个例子来解释下遗忘门的作用(参考ng deep learning课):
    lets assume we are reading words in a piece of text, and want use an LSTM to keep track of grammatical structures, such as whether the subject is singular or plural. If the subject changes from a singular word to a plural word, we need to find a way to get rid of our previously stored memory value of the singular/plural state.
    遗忘门的公式为 Γ f &lt; t &gt; = σ ( W f [ a &lt; t − 1 &gt; , x &lt; t &gt; ] + b f ) \Gamma_{f}^{&lt;t&gt;} = \sigma (W_{f}[a^{&lt;t-1&gt;},x^{&lt;t&gt;}] + b_f) Γf<t>=σ(Wf[a<t1>,x<t>]+bf) σ \sigma σ为sigmoid函数,因此 Γ f \Gamma _{f} Γf的值在0到1之间,也就是说遗忘门通过看上一个隐藏状态( a &lt; t − 1 &gt; a^{&lt;t-1&gt;} a<t1>)和当前的输入( x &lt; t &gt; x^{&lt;t&gt;} x<t>)来得到一个值( Γ f \Gamma _{f} Γf),然后用 Γ f \Gamma _{f} Γf去点乘 C &lt; t − 1 &gt; C^{&lt;t-1&gt;} C<t1>(上一个memory cell)去决定上一个memory cell的信息是否保留,0表示丢弃,1表示保留。

  • input gate(输入门):这个有的也叫update gate。输入门的作用是为了更新cell state的时候,来决定哪些值需要被更新。

  • output gate(输出门):决定cell state里哪些值应该被输出(即下一个cell state的值)。

再上一张更加清晰地LSTM cell图,来自斯坦福的CS224D

LSTM展开图如下所示:

二、LSTM的计算流程

关于LSTM的计算过程,博客 Understanding LSTM Networks中给出了详细了过程,建议大家认真看一下。我这里只简要的写一下流程(参考了上面的博客):
1、首先计算得到 Γ f &lt; t &gt; = σ ( W f [ a &lt; t − 1 &gt; , x &lt; t &gt; ] + b f ) \Gamma_{f}^{&lt;t&gt;} = \sigma (W_{f}[a^{&lt;t-1&gt;},x^{&lt;t&gt;}] + b_f) Γf<t>=σ(Wf[a<t1>,x<t>]+bf)
2、计算 Γ i &lt; t &gt; = σ ( W i [ a &lt; t − 1 &gt; , x &lt; t &gt; ] + b i ) \Gamma_{i}^{&lt;t&gt;} = \sigma (W_{i}[a^{&lt;t-1&gt;},x^{&lt;t&gt;}] + b_i) Γi<t>=σ(Wi[a<t1>,x<t>]+bi) C ~ &lt; t &gt; = tanh ⁡ ( W c [ a &lt; t − 1 &gt; , x &lt; t &gt; ] + b c ) \widetilde{C}^{&lt;t&gt;} = \tanh (W_{c}[a^{&lt;t-1&gt;},x^{&lt;t&gt;}] + b_c) C <t>=tanh(Wc[a<t1>,x<t>]+bc)
3、得到新的cell state: C &lt; t &gt; = Γ f &lt; t &gt; ⊙ C &lt; t − 1 &gt; + Γ i &lt; t &gt; ⊙ C ~ &lt; t &gt; C^{&lt;t&gt;} = \Gamma_{f}^{&lt;t&gt;}\odot C^{&lt;t-1&gt;} + \Gamma_{i}^{&lt;t&gt;}\odot \widetilde{C}^{&lt;t&gt;} C<t>=Γf<t>C<t1>+Γi<t>C <t>,其中 Γ f &lt; t &gt; ⊙ C &lt; t − 1 &gt; \Gamma_{f}^{&lt;t&gt;}\odot C^{&lt;t-1&gt;} Γf<t>C<t1>遗忘门决定上一个memory cell里有多少信息被保留下来, Γ i &lt; t &gt; ⊙ C ~ &lt; t &gt; \Gamma_{i}^{&lt;t&gt;}\odot \widetilde{C}^{&lt;t&gt;} Γi<t>C <t>表示哪些新的信息被添加到当前的memory cell里。
4、计算 Γ o &lt; t &gt; = σ ( W o [ a &lt; t − 1 &gt; , x &lt; t &gt; ] + b o ) \Gamma_{o}^{&lt;t&gt;} = \sigma (W_{o}[a^{&lt;t-1&gt;},x^{&lt;t&gt;}] + b_o) Γo<t>=σ(Wo[a<t1>,x<t>]+bo)
5、决定当前memory cell中哪些信息被输出: a &lt; t &gt; = Γ o &lt; t &gt; ⊙ t a n h ( c &lt; t &gt; ) a^{&lt;t&gt;} = \Gamma_{o}^{&lt;t&gt;}\odot tanh(c^{&lt;t&gt;}) a<t>=Γo<t>tanh(c<t>)

关于LSTM的简单介绍到这就介绍完了,现在回过头去回答上面提出的问题:“LSTM是如何缓解梯度消失的?”

RNN产生梯度消失的根本原因在于连乘导致(梯度小于1,将->0,大于1,将->无穷,梯度爆炸),这点可以从bp的公式推导得到,这里不就推导了,博客漫谈LSTM系列的梯度问题做了推导,可以参考这一篇。而LSTM的魅力在于把连乘变成了相加,公式推导同样参考上面的文章。这里从上面的文章中摘取最重要的一部分:



下面是一些个人认为对理解LSTM比较好的文章博客,建议大家仔细看一看:

[1]: Understanding LSTM Networks
[2]: 漫谈LSTM系列的梯度问题
[3]: http://cs224d.stanford.edu/lecture_notes/notes4.pdf





评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值