LSTM的全称是Long Short Term Memory,它是具有记忆长短期信息的能力的神经网络,是一种改进之后的循环神经网络。提出的动机是为了解决普通RNN网络的长期依赖问题(具体细节直接搜索 )。原始 RNN 的隐藏层只有一个状态,即 h h h,它对于短期的输入非常敏感。LSTM再增加一个状态,即 C C C,让它来保存长期的状态,称为单元状态(cell state)。
RNN是一个链式结构,每个时间片使用的是相同的参数。下面是典型的网络结构图:
LSTM和普通的RNN结构不同,典型网络结构如下图:
上图中每个黄色方框表示一个神经网络层,由权值,偏置以及激活函数组成;每个粉色圆圈表示元素级别操作;箭头表示向量流向;相交的箭头表示向量的拼接;分叉的箭头表示向量的复制。
在 t 时刻,LSTM 的输入有三个:当前时刻网络的输入值
x
t
x_t
xt、上一时刻 LSTM 的输出值
h
t
−
1
h_{t-1}
ht−1、以及上一时刻的单元状态
C
t
−
1
C_{t-1}
Ct−1;LSTM 的输出有两个:当前时刻 LSTM 输出值
h
t
h_t
ht、和当前时刻的单元状态
C
t
C_t
Ct.
本文将对LSTM的整体结构分部分展开。
1、下面是LSTM最重要的部分,即单元状态(cell state):
其中:
C
t
=
f
t
.
C
t
−
1
+
i
t
.
C
~
t
C_t=f_t.C_{t-1}+i_t.\widetilde{C}_t
Ct=ft.Ct−1+it.C
t
由上一次的单元状态
C
t
−
1
C_{t-1}
Ct−1 按元素乘以遗忘门
f
t
f_t
ft,再用当前输入的单元状态
C
t
C_t
Ct 按元素乘以输入门
i
t
i_t
it,再将两个积加和:这样,就可以把当前的记忆
C
t
C_t
Ct 和长期的记忆
C
t
−
1
C_{t-1}
Ct−1 组合在一起,形成了新的单元状态
C
t
C_t
Ct。
2、遗忘门
单元状态计算公式中的
f
t
f_t
ft 叫做遗忘门(如下图所示),表示
C
t
−
1
C_{t-1}
Ct−1的哪些特征被用于计算
C
t
C_t
Ct 。
f
t
f_t
ft是一个向量,向量的每个元素均位于
[
0
−
1
]
[0-1]
[0−1]范围内。通常我们使用
s
i
g
m
o
i
d
sigmoid
sigmoid作为激活函数,
s
i
g
m
o
i
d
sigmoid
sigmoid 的输出是一个介于
[
0
−
1
]
[0-1]
[0−1] 区间内的值,但是当你观察一个训练好的LSTM时,你会发现门的值绝大多数都非常接近0或者1,其余的值少之又少。其中
⊗
\otimes
⊗ 是LSTM最重要的门机制,表示
f
t
f_t
ft和
C
t
−
1
C_{t-1}
Ct−1 之间的单位乘的关系。
其中:
f
t
=
σ
(
W
f
.
[
h
t
−
1
,
x
t
]
+
b
f
)
f_t=\sigma(W_f.[h_{t-1},x_t]+b_f)
ft=σ(Wf.[ht−1,xt]+bf)
3、输入门
如下图所示, C ~ t \widetilde{C}_t C t 表示单元状态更新值,由输入数据 x t x_t xt 和隐节点 h t − 1 h_{t-1} ht−1 经由一个神经网络层得到,单元状态更新值的激活函数通常使用 t a n h tanh tanh。 i t i_t it 叫做输入门,同 f t f_t ft一样也是一个元素介于 [ 0 − 1 ] [0-1] [0−1]区间内的向量,同样由 x t x_t xt 和 h t − 1 h_{t-1} ht−1经由 s i g m o i d sigmoid sigmoid 激活函数计算而成。
其中:
i
t
=
σ
(
W
i
.
[
h
t
−
1
,
x
t
]
+
b
i
)
i_t=\sigma(W_i.[h_{t-1},x_t]+b_i)
it=σ(Wi.[ht−1,xt]+bi)
C
~
t
=
t
a
n
h
(
W
C
.
[
h
t
−
1
,
x
t
]
+
b
C
)
\widetilde{C}_t=tanh(W_C.[h_{t-1},x_t]+b_C)
C
t=tanh(WC.[ht−1,xt]+bC)
i
t
i_t
it用于控制
C
~
t
\widetilde{C}_t
C
t的哪些特征用于更新
C
t
C_t
Ct ,使用方式和
f
t
f_t
ft 相同(如下图)。
4、输出门
最后,为了计算预测值
y
^
t
\hat{y}_t
y^t 和生成下个时间片完整的输入,我们需要计算隐节点的输出
h
t
h_t
ht( 如下图)。
其中:
o
t
=
σ
(
W
o
.
[
h
t
−
1
,
x
t
]
+
b
o
)
o_t=\sigma(W_o.[h_{t-1},x_t]+b_o)
ot=σ(Wo.[ht−1,xt]+bo)
h
t
=
o
t
.
t
a
n
h
(
C
t
)
h_t=o_t.tanh(C_t)
ht=ot.tanh(Ct)
h
t
h_t
ht 由输出门
o
t
o_t
ot 和单元状态
C
t
C_t
Ct 得到,其中
o
t
o_t
ot 的计算方式和
f
t
f_t
ft 以及
i
t
i_t
it 相同。
参考文章:
https://zhuanlan.zhihu.com/p/42717426
https://blog.csdn.net/qq_31278903/article/details/88690959