小袁讲长短期记忆网络(LSTM)

一, 什么是长短期

LSTM全名“ Long Short-term Memory”,中文名翻译为长短期记忆网络。小袁我刚接触这个网络的时候,一度以为长短期记忆网络既可以建模序列问题中的长期时间依赖,又可以有效地捕捉到序列数据的短期时间依赖,因而被命名为长短期记忆网络。事实上这样理解对也不对,对在LSTM确实既有捕捉序列数据的长的时间依赖,又有捕捉短的时间依赖的特性上。不对在LSTM的特性并不像我们通俗理解的长短期。英文表达而言就是“Long Short-term Memory” 和 “Long Short term Memory”的差别吧。这篇博客我会重点讲下我对“长短期”的理解,如有不正确的地方还望各位不吝指教!

注:本博客部分图片公式来源于网络,侵删。转载请注明出处!

1.1 为何会有LSTM

据各路文献博客所言,LSTM的提出是为了解决循环神经网络(RNN)无法捕捉序列的长期时间依赖的不足,RNN的核心状态更新公式为
h t = f ( W i x t + W h h t − 1 ) h_t=f(W^ix_t + W^hh_{t-1}) ht=f(Wixt+Whht1)
其中, h t h_t ht 为RNN网络的隐藏层在时刻 t t t 的状态值, f ( ) f() f() 为RNN网络的激活函数,通常为 t a n h tanh tanh函数。

RNN的一种网络拓扑结构如下图所示:
在这里插入图片描述
由于第 t t t 时刻block内(上图中的绿框)的输入仅为上一时刻 t − 1 t-1 t1 的状态值 h t − 1 h_{t-1} ht1 和当前时刻的输入 x t x_t xt ,因而RNN无法捕捉到序列数据的长期依赖,仅能捕捉到序列数据的短期依赖,这导致了RNN网络在建模上的天然不足。

事实上,对RNN的这种理解是不对的。 这种有失偏颇的理解会进一步给自己理解LSTM带来困难。上述理解主要问题在于第 t t t 时刻block内(上图中的绿框)的输入之一 h t − 1 h_{t-1} ht1不是一个独立的变量,它的值通过 h t − 2 h_{t-2} ht2 x t − 1 x_{t-1} xt1 计算得到(即 h t − 1 h_{t-1} ht1 包含 h t − 2 h_{t-2} ht2 的特征信息 )。递归地, h t h_t ht 包含 t = 1 , 2 , . . . , t − 1 t=1,2,...,t-1 t=1,2,...,t1 的所有隐藏层的状态特征,因而RNN事实上是有建模长期时间依赖的能力的。既然如此,那为何会有RNN无法捕捉长期的序列时间依赖关系的说法呢?所谓无风不起浪啊。事实上,这个可以用“理想很丰满,现实很骨干”来比喻。尽管RNN能够完美的建模序列数据的长期依赖关系,但是它没法用啊,因为传统的RNN非常容易陷入梯度消失或梯度爆炸问题,这导致了RNN网络在实际使用中,无法捕捉到序列的长期依赖关系。事实上相应的长短期记忆网络LSTM也是因为它在实际应用中能够巧妙地避免梯度消失或梯度爆炸问题,使得它能够捕捉到长期的序列时间依赖关系。简言之,LSTM的提出是为了克服在实际应用中 ,RNN建模的长期时间依赖关系无法通过梯度优化的不足。

1.2 谈谈RNN的梯度消失和梯度爆炸

关于RNN的梯度消失和梯度爆炸问题,参考了知乎文章 ,并结合评论和我的理解做了部分修正。具体细节如下所示:

定义参数优化的损失函数
L = ∑ t = 0 T L t L=\sum_{t=0}^{T}L_{t} L=t=0TLt
则损失函数 L L L 对参数矩阵 W W W 的偏导数为
∂ L ∂ W = ∑ t = 0 T ∂ L t ∂ W \frac{\partial L}{\partial W} =\sum_{t=0}^{T}\frac{\partial L_t}{\partial W} WL=t=0TWLt
现考虑 t t t 时刻的损失函数误差对输出矩阵 W o W^o Wo ,隐藏层矩阵 W h W^h Wh,输入矩阵 W i W^i Wi 的偏导数,它们依次为
∂ L t ∂ W o = ∂ L t ∂ y t ∂ y t ∂ W o \frac{\partial L_t}{\partial W^o} =\frac{\partial L_t}{\partial y_t}\frac{\partial y_t}{\partial W^o} WoLt=ytLtWoyt

∂ L t ∂ W h = ∑ k = 0 t ∂ L t ∂ y t ∂ y t ∂ h t ( ∏ i = k + 1 t ∂ h i ∂ h i − 1 ) ∂ h k ∂ W h \frac{\partial L_t}{\partial W^h} =\sum_{k=0}^{t}\frac{\partial L_t}{\partial y_t}\frac{\partial y_t}{\partial h_t}(\prod_{i=k+1}^{t}\frac{\partial h_i}{\partial h_{i-1}})\frac{\partial h_{k}}{\partial W^h} WhLt=k=0tytLthtyt(i=k+1thi1hi)Whhk

∂ L t ∂ W i = ∑ k = 0 t ∂ L t ∂ y t ∂ y t ∂ h t ( ∏ i = k + 1 t ∂ h i ∂ h i − 1 ) ∂ h k ∂ W i \frac{\partial L_t}{\partial W^i} =\sum_{k=0}^{t}\frac{\partial L_t}{\partial y_t}\frac{\partial y_t}{\partial h_t}(\prod_{i=k+1}^{t}\frac{\partial h_i}{\partial h_{i-1}})\frac{\partial h_{k}}{\partial W^i} WiLt=k=0tytLthtyt(i=k+1thi1hi)Wihk

可以看到,在修正某个时刻 t t t 的误差时,离时刻 t t t越久远的时刻 k k k需要考虑到的隐藏层之间的偏导数 ∂ h i ∂ h i − 1 \frac{\partial h_i}{\partial h_{i-1}} hi1hi 的连乘次数越多。为了方便理解,我们假设 x t x_t xt h t h_t ht 均为一维变量,则 W h W_h Wh W i W_i Wi均为一维变量。 现在我们考虑下 ∏ i = k + 1 t ∂ h i ∂ h i − 1 \prod_{i=k+1}^{t}\frac{\partial h_i}{\partial h_{i-1}} i=k+1thi1hi ,由RNN的定义公式 h t = f ( W i x t + W h h t − 1 ) h_t=f(W^ix_t + W^hh_{t-1}) ht=f(Wixt+Whht1)
∂ h t ∂ h t − 1 = W h f ′ \frac{\partial h_t}{\partial h_{t-1}}=W_hf' ht1ht=Whf
因为 f f f 为sigmoid函数,所以它的导数的上届为0.25,所以有
∂ h t ∂ h t − 1 = W h f ′ ≤ 0.25 W h \frac{\partial h_t}{\partial h_{t-1}}=W_hf'\leq0.25W_h ht1ht=Whf0.25Wh
如果 W h ≤ 4 W_h\leq4 Wh4,则恒有 ∂ h t ∂ h t − 1 \frac{\partial h_t}{\partial h_{t-1}} ht1ht小于1。此时,若时刻 t t t与时刻 k k k的时差较大,则 ∏ i = k + 1 t ∂ h i ∂ h i − 1 \prod_{i=k+1}^{t}\frac{\partial h_i}{\partial h_{i-1}} i=k+1thi1hi趋于0。此时发生梯度消失现象。

或者某种情况下 W h f ′ ≥ 1 W_hf'\geq1 Whf1,即 ∂ h t ∂ h t − 1 \frac{\partial h_t}{\partial h_{t-1}} ht1ht大于1。此时,若时刻 t t t与时刻 k k k的时差较大,则 ∏ i = k + 1 t ∂ h i ∂ h i − 1 \prod_{i=k+1}^{t}\frac{\partial h_i}{\partial h_{i-1}} i=k+1thi1hi趋于无穷大。此时发生梯度爆炸现象。

既然如此,那我可否在初始化的时候选择一个好的 W h W_h Wh,使得刚好不会发生梯度消失和爆炸呢?事实上,一个好的初始化确实可以避免迭代算法在开始时更可能避免梯度消失和爆炸,然而随着迭代次数的增加,更新后的 W h W_h Wh就无法保证了,更多详细资料参考博客

1.3 LSTM的基本组成

如下给出了两种LSTM的框架单元表示图,在此我们不去细究图中每个变量的含义,有兴趣的伙伴参考李宏毅老师的教学视频。针对LSTM网络我们选择从公式出发去介绍LSTM。
在这里插入图片描述
首先我们给出几个基本的符号定义:

σ σ σ表示sigmoid函数

h t h_t ht表示 t t t时刻隐藏层的状态值

C t C_t Ct表示 t t t时刻细胞层的状态值

o t , f t , i t o_t, f_t, i_t ot,ft,it依次表示 t t t时刻输出门,遗忘门和输入门的状态值

x t , y t x_t, y_t xt,yt表示 t t t时刻网络的输入和输出

W R W_R WR表示不同的网络权重(不同的权重用不同下标表示)

对于 t t t时刻的网络的block(上图中的单个绿框),其信号输入为 x t x_t xt t − 1 t-1 t1时刻的block的隐藏层输出 h t − 1 h_{t-1} ht1和细胞层输出 C t − 1 C_{t-1} Ct1;信号输出为 h t , C t h_t, C_t ht,Ct C t C_t Ct的计算公式如下所示
C t = f t C t − 1 + i t C t ~ C_t=f_tC_{t-1}+i_t\tilde{C_t} Ct=ftCt1+itCt~

C t ~ = t a n h ( W c [ h t − 1 , x t ] ) \tilde{C_t}=tanh(W_c[h_{t-1},x_t]) Ct~=tanh(Wc[ht1,xt])

h t h_t ht的计算公式如下所示
h t = o t t a n h ( C t ) h_t=o_ttanh(C_t) ht=ottanh(Ct)
其中, t t t时刻输出门,遗忘门和输入门的状态值的计算公式如下所示
o t = σ ( W o [ h t − 1 , x t ] ) o_t = σ(W_{o}[h_{t-1},x_t]) ot=σ(Wo[ht1,xt])

f t = σ ( W f [ h t − 1 , x t ] ) f_t = σ(W_{f}[h_{t-1},x_t]) ft=σ(Wf[ht1,xt])

i t = σ ( W i [ h t − 1 , x t ] ) i_t = σ(W_{i}[h_{t-1},x_t]) it=σ(Wi[ht1,xt])

需要强调的是,现有的一些有关LSTM的框架流程图只能宏观的表示网络的输入输出和大致的流程,于小袁我而言这些流程图对于LSTM的刻画程度并没有公式来的直接和具体,因而小袁还是建议感兴趣的伙伴可以多多钻研钻研公式。

二,对LSTM的两脸懵逼

2.1 懵逼一:这个结构得多大脑洞想的

在上面的讲解中我们已经知道,RNN在梯度更新权重的过程中存在梯度消失问题,那LSTM网络就和小葵花妈妈一样,自然是哪里有问题改哪里。1997年Sepp Hochreiter在提出长短期记忆网络LSTM时,网络中的遗忘门的值 f t = 1 f_t = 1 ft=1,在这篇论文中,作者指出设计输入门和输出门的原因主要是为了解决冲突,原文如下:

  1. Input weight conflict: for simplicity, let us focus on a single additional input weight w j i w_{ji} wji. Assume that the total error can be reduced by switching on unit j j j in response to a certain input, and keeping it active for a long time (until it helps to compute a desired output). Provided i i i is non- zero, since the same incoming weight has to be used for both storing certain inputs and ignoring others, w j i w_{ji} wji will often receive conflicting weight update signals during this time (recall that j j j is linear): these signals will attempt to make w j i w_{ji} wji participate in (1) storing the input (by switching on j j j) and (2) protecting the input (by preventing j j j from being switched off by irrelevant later inputs). This conflict makes learning difficult, and calls for a more context-sensitive mechanism for controlling “write operations” through input weights.
  2. Output weight conflict: assume j j j is switched on and currently stores some previous input. For simplicity, let us focus on a single additional outgoing weight w k j w_{kj} wkj . The same w k j w_{kj} wkj has to be used for both retrieving j j j 's content at certain times and preventing j j j from disturbing k k k at other times. As long as unit j j j is non-zero, w k j w_{kj} wkj will attract conflicting weight update signals generated during sequence processing: these signals will attempt to make w k j w_{kj} wkj participate in (1) accessing the information stored in j j j and — at different times — (2) protecting unit k k k from being perturbed by j j j . For instance, with many tasks there are certain “short time lag errors” that can be reduced in early training stages. However, at later training stages j j j may suddenly start to cause avoidable errors in situations that already seemed under control by attempting to participate in reducing more difficult “long time lag errors”. Again, this conflict makes learning difficult, and calls for a more context-sensitive mechanism for controlling “read operations” through output weights.

就小袁个人理解而言,输入门和输出门是一种对信息的筛选机制,比如阻止 t − 1 t-1 t1 t − 2 t-2 t2时刻的网络的输入 x t − 1 x_{t-1} xt1 x t − 2 x_{t-2} xt2 t t t时刻的细胞状态值 C t C_t Ct的直接影响,则我只需要简单将输入门 i t − 1 i_{t-1} it1 i t − 2 i_{t-2} it2置零。在此,我们不妨假设在任何时刻,如果LSTM的 o t , i t o_t,i_t ot,it恒为1, f t f_t ft恒为1,此时的网络称为退化的LSTM。有:

对于 t t t时刻的网络的block(上图中的单个绿框),其信号输入为 x t x_t xt t − 1 t-1 t1时刻的block的隐藏层输出 h t − 1 h_{t-1} ht1和细胞层输出 C t − 1 C_{t-1} Ct1;信号输出为 h t , C t h_t, C_t ht,Ct C t C_t Ct的计算公式如下所示
C t = C t − 1 + C t ~ C_t=C_{t-1}+\tilde{C_t} Ct=Ct1+Ct~

C t ~ = t a n h ( W c [ h t − 1 , x t ] ) \tilde{C_t}=tanh(W_c[h_{t-1},x_t]) Ct~=tanh(Wc[ht1,xt])

h t h_t ht的计算公式如下所示
h t = t a n h ( C t ) h_t=tanh(C_t) ht=tanh(Ct)
联合上述三个公式易知
h t = t a n h ( C t − 1 + t a n h ( W c [ h t − 1 , x t ] ) ) h_t=tanh(C_{t-1}+tanh(W_c[h_{t-1},x_t])) ht=tanh(Ct1+tanh(Wc[ht1,xt]))
对比此时退化的LSTM神经网络结构和RNN更新方程
h t = t a n h ( W i x t + W h h t − 1 ) h_t=tanh(W^ix_t + W^hh_{t-1}) ht=tanh(Wixt+Whht1)
可以看到相比RNN的结构而言,退化的LSTM引进了细胞的状态层,相比RNN,网络的单个block深了一层。

2.2 懵逼二: 咋就避免序列数据的梯度消失

如1.2所言,递归导数是导致梯度消失的主要因素,因此我们分析下在LSTM中,递归导数的数学表达。首先,LSTM的细胞的状态值更新公式为:
C t = f t C t − 1 + i t C t ~ C_t=f_tC_{t-1}+i_t\tilde{C_t} Ct=ftCt1+itCt~
又在上述公式中, f t , i t , C t ~ f_t,i_t,\tilde{C_t} ft,it,Ct~是关于隐藏状态值 h t − 1 h_{t-1} ht1的函数, h t − 1 h_{t-1} ht1是关于 C t − 1 C_{t-1} Ct1的函数,因此根据链式求导法则,有
∂ C t ∂ C t − 1 = ∂ C t ∂ f t ∂ f t ∂ h t − 1 ∂ h t − 1 ∂ C t − 1 + ∂ C t ∂ C t − 1 + ∂ C t ∂ i t ∂ i t ∂ h t − 1 ∂ h t − 1 ∂ C t − 1 + ∂ C t ∂ C ~ t ∂ C ~ t ∂ h t − 1 ∂ h t − 1 ∂ C t − 1 \frac{\partial C_t}{\partial C_{t-1}}=\frac{\partial C_t}{\partial f_{t}}\frac{\partial f_t}{\partial h_{t-1}}\frac{\partial h_{t-1}}{\partial C_{t-1}}+\frac{\partial C_{t}}{\partial C_{t-1}}+\frac{\partial C_t}{\partial i_{t}}\frac{\partial i_t}{\partial h_{t-1}}\frac{\partial h_{t-1}}{\partial C_{t-1}}+\frac{\partial C_t}{\partial \tilde C_{t}}\frac{\partial\tilde C_t}{\partial h_{t-1}}\frac{\partial h_{t-1}}{\partial C_{t-1}} Ct1Ct=ftCtht1ftCt1ht1+Ct1Ct+itCtht1itCt1ht1+C~tCtht1C~tCt1ht1
化简有
∂ C t ∂ C t − 1 = C t − 1 σ ′ W f ∗ o t − 1 t a n h ′ ( C t − 1 ) + f t + C ~ t σ ′ W i ∗ o t − 1 t a n h ′ ( C t − 1 ) + i t t a n h ′ W c ∗ o t − 1 t a n h ′ ( C t − 1 ) \frac{\partial C_t}{\partial C_{t-1}}=C_{t-1} σ'W_f*o_{t-1}tanh'(C_{t-1})+f_t+\tilde C_tσ'W_i*o_{t-1}tanh'(C_{t-1})+i_{t}tanh'W_c*o_{t-1}tanh'(C_{t-1}) Ct1Ct=Ct1σWfot1tanh(Ct1)+ft+C~tσWiot1tanh(Ct1)+ittanhWcot1tanh(Ct1)
我们现在将RNN的递归导数列出来,如下所示
∂ h t ∂ h t − 1 = W h f ′ \frac{\partial h_t}{\partial h_{t-1}}=W_hf' ht1ht=Whf
容易看到,LSTM的递归导数的值的大小与时间 t t t有关,即不同时刻的值可以大于1,或者在0~1区间。然而在RNN中,一旦 W h < 4 W_h<4 Wh<4(假设 W h W_h Wh为维度为1),则所有时刻的递归导数的值均小于1.这就使得RNN相比LSTM更易发生梯度消失问题。用weberna的博客的话说:“ In vanilla RNNs, the terms ∂ h t ∂ h t − 1 \frac{\partial h_t}{\partial h_{t-1}} ht1ht will eventually take on a values that are either always above 1 or always in the range [0,1], this is essentially what leads to the vanishing/exploding gradient problem. The terms here, ∂ C t ∂ C t − 1 \frac{\partial C_t}{\partial C_{t-1}} Ct1Ct ,at any time step can take on either values that are greater than 1 or values in the range [0,1]. Thus if we extend to an infinite amount of time steps, it is not guarented that we will end up converging to 0 or infinity (unlike in vanilla RNNs).”

也就是说,LSTM并不能保证能完全避免梯度消失,只是相比与RNN,递归导数中的 f t , o t , i t f_t,o_t,i_t ft,ot,it的值是由数据驱动的,可调整的,因而更容易避免梯度迭代优化算法中的梯度消失问题。

三,博主碎碎念

在LSTM的学习理解过程中,博主觉得比较好的三个学习链接,推荐给大家:

如果你对LSTM的设计思路感兴趣,后者你有其它的understanding或idea,欢迎来私戳博主交流。

博主目前是枚科研秃头怪,如果你也撰写博客或者看到过一些经典算法的好博客,能推荐给我的话我将非常感谢!

  • 2
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值