LSTM是传统RNN网络的扩展,其核心结构是其cell单元,网上LSTM的相关资料繁多,质量参差不齐,下面主要结合LSTM神经网络的详细推导和 Christopher Olah的blog两篇文章中的内容进行说明。主要介绍网络如何计算,为何这么算先不展开:)。前者一副图加上29个公式,简洁明了;后者娓娓道来,适合初学者。
首先是LSTM cell最常见的结构图:
- 1
- 2
这是变形的版本(找不到更清晰的版本了),其中输入门控制输入(新记忆)的输入幅度;遗忘门控制之前记忆状态的输入幅度;输出门控制最终记忆的输出幅度。图中的三角形其实就是乘法符号。
- 1
- 2
t时刻cell的Input:
1.由当前输入Xt
2.前一时刻cell的输出ht-1
3.前一时刻cell的状态ct-1(可以理解为计算ht-1过程中的中间值)
t时刻cell的3个控制门Gate,值域[0,1](改进的GRU的cell将输入门和遗忘门合并为Update门):
1.输入门it
2.遗忘门ft
3.输出门ot
计算过程如下(请对照上面第二个结构图):
step 1.1 输入门it
step 1.2 及其控制的新记忆Ct波浪线:)(如下图)
W是其对应的权重矩阵,b为偏置。黄色的框内是不同的激活函数。其实这两个运算可以等效为两层并行的神经网络。
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
step 1.3 遗忘门ft (控制对于之前输入记忆ct-1的遗忘程度)(如下图)
其中,step1.1、1.2和1.3是可以并行计算的,输入都是当前输入Xt 和 前一时刻cell的输出ht-1
- 1
- 2
- 3
step 2 当前t时刻cell的状态Ct(由step 1计算的三个结果得到)
- 1
- 2
step 3 输出门Ot及其控制的t时刻cell的输出ht
- 1
- 2
step 4 信号xt通过ht的输出:
- 1
- 2
以上其实是lstm的前向传播过程,反向传播求解梯度及参数更新具体参考LSTM神经网络的详细推导