长短时记忆网络
(Long Short-Term Memory,LSTM)
不管是我们还是计算机都很难有过目不忘的记忆,当看到一篇长文时,通常只会几下长文中内句话所讲的核心,而一些不太起眼的词汇将会被忘掉。
LSTM网络就是模仿人的这个特点,在计算机处理很多文字时有侧重点的记住具有重要意义的词汇,忘记一些作用不大的词汇。
长短时记忆网络通过不断地调用同一个cell逐次处理时序信息,每阅读一个词汇xt,就会输出一个新的重点记忆信号ht,用来表示当前阅读到的所有内容的整体(具有侧重点的)向量表示。下图就是LSTM网络的内部结构示意图,我们先不提他的内部结构看一下LSTM网络的输入与输出都是什么。
C
t
:
全
部
记
忆
h
t
:
有
侧
重
的
记
忆
C_t\;:\;\;\;\;\;\mathrm{全部记忆}\;\;\;\;\\h_{t\;}\;:\;\mathrm{有侧重的记忆}
Ct:全部记忆ht:有侧重的记忆
xt是被逐个词读取的,相当于从一篇文章的第一次开始一直向后读。
h(t-1) 和 C(t-1) 是前面经过处理后保留下来的记忆,h(t-1) 是有侧重点记忆。
h(t) 和 C(t) 则是通过学习当前的xt后形成的记忆,需要传递给下一cell。
h(t-1)与x(t)作为输入,经过处理后汇总进C(t)当中。
最开始学习LSTM需要做的是把输入输出搞清楚(整体的公式写在在本文的最后面),下面来说一说网络内部的结构:
遗忘门
在上图LSTM网络结构图中,函数ft代表着遗忘门(f --> forget),控制着过去记忆融合进C(t)的比例。这里使用sigmoid函数调节信息的重要程度。
f
t
=
s
i
g
m
o
i
d
(
W
f
X
t
+
V
f
H
t
−
1
+
b
f
)
f_t\;=\;sigmoid(\;W_fX_t\;+\;V_fH_{t-1}\;+\;b_f\;)
ft=sigmoid(WfXt+VfHt−1+bf)
输入门
在上图LSTM网络结构图中,函数it代表着输入门(i–>input),控制着有多少输入信号被融合。
i
t
=
s
i
g
m
o
i
d
(
W
i
X
t
+
V
i
H
t
−
1
+
b
i
)
i_t\;=\;sigmoid(\;W_iX_t\;+\;V_iH_{t-1}\;+\;b_i\;)
it=sigmoid(WiXt+ViHt−1+bi)
图中的tanh函数处,在这里我们将他叫做单元状态gt,单元状态中以tanh函数对输入数据作处理来调节网络。
g
t
=
tan
h
(
W
g
X
t
+
V
g
H
t
−
1
+
b
g
)
g_t\;=\;\tan h(\;W_gX_t\;+\;V_gH_{t-1}\;+\;b_g\;)
gt=tanh(WgXt+VgHt−1+bg)
输入门最后在汇总的时候需要将it的输出与gt的输出相乘,下面这个图展现了两个门是如何控制遗忘和记忆的。
C
t
=
f
t
∗
C
t
−
1
+
i
t
∗
g
t
C_t\;=\;f_t\;\ast\;C_{t-1}\;\;+\:\;i_t\;\ast\;g_t
Ct=ft∗Ct−1+it∗gt
输出门
控制着最终输出多少记忆(Ot -->output)
输出们按比重学习、筛选输入信息然后汇入ht中传给下一个cell;这里需要对h(t-1)处理后再与tanh(Ct)相乘。
o
t
=
s
i
g
m
o
i
d
(
W
o
X
t
+
V
o
H
t
−
1
+
b
o
)
o_t\;=\;sigmoid(\;W_oX_t\;+\;V_oH_{t-1}\;+\;b_o\;)
ot=sigmoid(WoXt+VoHt−1+bo)
h
t
=
o
t
+
tan
h
(
C
t
)
h_t\;=\;o_t\;+\;\tan h(C_t)
ht=ot+tanh(Ct)
以上就是单层LSTM网络的大致结构。
其实单层LSTM把输入和输出关系捋清楚之后,似乎就很容易入门了。
C
t
=
f
t
∗
C
t
−
1
+
i
t
∗
g
t
h
t
=
o
t
+
tan
h
(
C
t
)
{
f
t
=
s
i
g
m
o
i
d
(
W
f
X
t
+
V
f
H
t
−
1
+
b
f
)
i
t
=
s
i
g
m
o
i
d
(
W
i
X
t
+
V
i
H
t
−
1
+
b
i
)
o
t
=
s
i
g
m
o
i
d
(
W
o
X
t
+
V
o
H
t
−
1
+
b
o
)
g
t
=
tan
h
(
W
g
X
t
+
V
g
H
t
−
1
+
b
g
)
C_t\;=\;f_t\;\ast\;C_{t-1}\;\;+\:\;i_t\;\ast\;g_t\\h_t\;=\;o_t\;+\;\tan h(C_t)\;\;\;\;\;\;\;\;\;\;\;\\\\\\\left\{\begin{array}{l}f_t\;=\;sigmoid(\;W_fX_t\;+\;V_fH_{t-1}\;+\;b_f\;)\\i_t\;=\;sigmoid(\;W_iX_t\;+\;V_iH_{t-1}\;+\;b_i\;)\\o_t\;=\;sigmoid(\;W_oX_t\;+\;V_oH_{t-1}\;+\;b_o\;)\\g_t\;=\;\tan h(\;W_gX_t\;+\;V_gH_{t-1}\;+\;b_g\;)\end{array}\right.
Ct=ft∗Ct−1+it∗gtht=ot+tanh(Ct)⎩⎪⎪⎨⎪⎪⎧ft=sigmoid(WfXt+VfHt−1+bf)it=sigmoid(WiXt+ViHt−1+bi)ot=sigmoid(WoXt+VoHt−1+bo)gt=tanh(WgXt+VgHt−1+bg)
为了表示输出的Ct和ht,才创建的其他别具职能的函数。其中输入x(t-1)和H(t-1)大多在下面的职能函数里面被带入。
参考:
入门学习并附带课程视频的paddle集训营课程
这个链接是能够加入到课程的AIstudio平台
对LSTM的理解下面的这篇博文介绍的比较详细易懂
对LSTM的理解