什么是LSTM
LSTM是SimpleRNN的变体,用来解决RNN中梯度消失的问题,相比普通的RNN,LATM在长序列中表现更好。
LSTM的原理
RNN通常只有一个传递状态 h t h^t ht, 而LSTM通常有两个状态: c t c^t ct和 h t h^t ht;
可以将 c t c^t ct看作一个传送带,包含了RNN网络中每个时间结点的信息,常输出的 c t c^t ct是上一个状态传过来的 c t − 1 c^{t-1} ct−1加上一些数值。
而 h t h^t ht通常在不同结点差别很大。
z = t a n h ( W ∗ ( x t , h t − 1 ) ) z=tanh(W*(x^t,h^{t-1})) z=tanh(W∗(xt,ht−1))
z i = σ ( W i ∗ ( x t , h t − 1 ) ) z^i=\sigma(W^i*(x^t,h^{t-1})) zi=σ(Wi∗(xt,ht−1))
z f = σ ( W f ∗ ( x t , h t − 1 ) ) z^f=\sigma(W^f*(x^t,h^{t-1})) zf=σ(Wf∗(xt,ht−1))
z o = σ ( W o ∗ ( x t , h t − 1 ) ) z^o=\sigma(W^o*(x^t,h^{t-1})) zo=σ(Wo∗(xt,ht−1))
其中, z i , z f , z o z^i,z^f,z^o zi,zf,zo是是由拼接向量乘以权重矩阵之后,再通过一个 s i g m o i d sigmoid sigmoid激活函数转换成0到1之间的数值,来作为一种门控状态。而 [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-JYaWITEv-1619274258999)(https://www.zhihu.com/equation?tex=z)] 则是将结果通过一个 t a n h tanh tanh激活函数将转换成-1到1之间的值(这里使用 t a n h tanh tanh是因为这里是将其做为输入数据,而不是门控信号)。
⊙
\odot
⊙是Hadamard Product,也就是操作矩阵中对应的元素相乘,因此要求两个相乘矩阵是同型的。
⊕
\oplus
⊕ 则代表进行矩阵加法
LSTM内部主要有三个阶段:
-
忘记阶段。这个阶段主要是对上一个节点传进来的输入进行选择性忘记。简单来说就是会 “忘记不重要的,记住重要的”。
具体来说是通过计算得到的 z f z^f zf来作为忘记门控,来控制上一个状态的 c t − 1 c^{t-1} ct−1 哪些需要留哪些需要忘。
-
选择记忆阶段。这个阶段将这个阶段的输入有选择性地进行“记忆”。主要是会对输入 x t x^t xt 进行选择记忆。哪些重要则着重记录下来,哪些不重要,则少记一些。当前的输入内容由前面计算得到的 z z z表示。而选择的门控信号则是由 z i z^i zi来进行控制。
将上面两步得到的结果相加,即可得到传输给下一个状态的 c t c^t ct 。也就是上图中的第一个公式。
- 输出阶段。这个阶段将决定哪些将会被当成当前状态的输出。主要是通过 z o z^o zo来进行控制的。并且还对上一阶段得到的 z o z^o zo进行了放缩(通过一个tanh激活函数进行变化)。
与普通RNN类似,输出 y t y^t yt 往往最终也是通过 h t h^t ht 变化得到。
优缺点
- 优点:解决了SimpleRNN梯度消失的问题,可以处理long-term sequence
- 缺点:计算复杂度高,想想谷歌翻译也只是7-8层LSTM就知道了;自己跑代码也有明显的感觉,比较慢。
参考文献
- LSTM原理与实践,原来如此简单
- [人人都能看懂的LSTM](