1 LSTM概述
RNN给神经网络加入了处理时间的能力,而传统的RNN会面临梯度消失(爆炸)的问题RNN vs LSTM: Vanishing Gradients,传递的时间信息也会越来越弱。给RNN引入长时记忆至关重要。因此有了Long Short Term Memory(LSTM)。
常见的LSTM结构如下图所示:
xt
x
t
为每个时间步的输入数据,
ht
h
t
为每个时间步的输出,中间的
ct
c
t
为单元间的长时记忆。
注意!该图容易产生误解!!此图是将LSTM按照时间维度进行了展开,实际上同一个时刻只有一个LSTM单元。
即如下图所示:
2 lstm 公式
每个LSTM单元用三个门来决定保留的信息,LSTM计算门和信息有6个公式,我们将公式罗列如下,并在下一章节详细分析这6个公式。
遗忘门(forget gate)
遗忘门 ft f t 对上次单元状态 c(t−1) c ( t − 1 ) 进行选择,它决定了上一时刻的单元状态 ct−1 c t − 1 有多少保留到当前时刻 ct c t
ft=σ(wf⋅[h(t−1),xt]+bf)(1) (1) f t = σ ( w f ⋅ [ h ( t − 1 ) , x t ] + b f )输入门(input gate)
它决定了当前时刻网络的输入 xt x t 有多少保存到单元状态 ct c t
it=σ(wi⋅[h(t−1),xt]+bi)(2) (2) i t = σ ( w i ⋅ [ h ( t − 1 ) , x t ] + b i )输出门(output gate)
控制单元状态 ct c t 有多少输出到 LSTM 的当前输出值 ht h t
ot=σ(wo⋅[h(t−1),xt]+bo)(3) (3) o t = σ ( w o ⋅ [ h ( t − 1 ) , x t ] + b o )输入信息 (Ct)̃ ( C t ) ̃
(Ct)̃ =tanh(wc⋅[h(t−1),xt]+bc)(4) (4) ( C t ) ̃ = t a n h ( w c ⋅ [ h ( t − 1 ) , x t ] + b c )本次单元状态 ct c t
本次单元状态由历史记忆和本次输入共同决定。
ct=ft⋅c(t−1)+it⋅(Ct)̃ (5) (5) c t = f t ⋅ c ( t − 1 ) + i t ⋅ ( C t ) ̃最终输出 ht h t
ht=ot⋅tanh(Ct)(6) (6) h t = o t ⋅ t a n h ( C t )
其中:
xt x t 为本时刻的数据
ct−1 c t − 1 为上个时刻单元的状态,保留着历史的记忆
Ct̃ C t ̃ 为本时刻的输入信息,代表着当前数据的信息
ct c t 为本时刻单元的状态
ht h t 为本时刻单元的输出
每个时刻处理时将 xt x t 与 ht−1 h t − 1 两个矩阵直接拼接在一起 [ht−1,xt] [ h t − 1 , x t ] 作为输入,加上适当的权重 w w 和偏置
3 LSTM cell分析
我们把单个LSTM的cell拎出来进行详细分析,将会发现其实LSTM很简单!
我们对照上图分别对LSTM单元的三个门进行分析:
从左到右观察三个紫色方框:
- 遗忘门对上次单元状态 ct−1 c t − 1 进行选择性遗忘,得到的输出为历史信息的记忆,如公式 (1) (1) 。
- 输入门对当前输入状态 Ct̃ C t ̃ 进行选择性输入,得到的输出为本次输入信息,如公式 (2) (2) 。
- 历史信息的记忆与本次输入的信息求和即是本次单元状态 ct c t ,如公式 (5) (5)
- 输出门对本次单元状态进行选择性输出,得到了最终的输出 ht h t ,如公式 (3) (3)
4 致谢
4.1 图
- 本文图1和4来自于知乎用户EDU GUO在该问题下的回答。图4中注释为本人添加。
- 图2和图3来自于 Christopher Olah 的博文。
4.2 参考
本文主要参考了以下博文:
1. 详解 LSTM:https://www.jianshu.com/p/dcec3f07d3b5
2. Understanding LSTM Networks:http://colah.github.io/posts/2015-08-Understanding-LSTMs/
3. 知乎问题,LSTM神经网络输入输出究竟是怎样的?:https://www.zhihu.com/question/41949741这里写链接内容