网络上大部分介绍LSTM的博客都是先给出一张图,然后进行分析。我们今天反其道行之,分析LSTM先不看图。看图反而会吸引你太多的注意力,无法很好的参透其中的奥妙。
咱们首先明确一下,LSTM有哪几个部分组成:
a) 三个门:输入门、输出门、遗忘门,作用暂且不表,只需要知道门可以控制信息的过滤;
b) Cell state:中文翻译为细胞态,感觉并不是很准确,索性就不翻译了。Cell state变量存储的是当前时刻t及其前面所有时刻的混合信息,也就是说,在LSTM中,信息的记忆与维护都是通过cell state变量的。所以,cell state非常重要!
c) hidden state:当前时刻的隐状态。刚接触LSTM时,一直搞不懂hidden state与cell state的作用。因为,在一般的RNN范式中,只有hidden_state这个概念,为啥LSTM中又多出了个Cell state?在b)中已经提到,cell state维护的是到目前时刻t为止,所有积累的信息。看完具体公式会发现,LSTM中的hidden_state其实就是cell state的一种过滤之后的信息,更关注当前时间点的输出结果。LSTM的hidden state其实就是当前时刻的output。
d) 输入
x
t
x_t
xt。当前时间点的输入。不过,需要注意的是,在LSTM每一个时间步中,最终输入其实由
x
t
x_t
xt及上一时刻隐状态
h
t
−
1
h_{t-1}
ht−1组成。
对上述几个部分有个基本概念后,下面再看具体的公式,应该会清晰很多。首先,咱们来看下三个门的计算原理:
输入门:
i
t
=
σ
(
W
i
[
x
t
,
h
t
−
1
]
+
b
i
)
i_t=\sigma(W_i[x_t,h_{t-1}]+b_i)
it=σ(Wi[xt,ht−1]+bi)
输出门:
o
t
=
σ
(
W
o
[
x
t
,
h
t
−
1
]
+
b
o
)
o_t=\sigma(W_o[x_t,h_{t-1}]+b_o)
ot=σ(Wo[xt,ht−1]+bo)
遗忘门:
f
t
=
σ
(
W
f
[
x
t
,
h
t
−
1
]
+
b
f
)
f_t=\sigma(W_f[x_t,h_{t-1}]+b_f)
ft=σ(Wf[xt,ht−1]+bf)
很明显的看出来,三个门的形式一模一样,输入也是一模一样,都是当前时间点输入
x
t
x_t
xt及前一时刻隐状态
h
t
−
1
h_{t-1}
ht−1两者的concat向量。唯一不同的可能就是在具体初始化时,权重被设置的不同。
简单看完门如何设置之后,下面就来看看重中之重:Cell state如何通过门进行更新:
当前时间
t
t
t的Cell state
C
t
C_t
Ct有两部分组成,一是继承自上一时刻的Cell state
C
t
−
1
C_{t-1}
Ct−1,二是来源于当前时刻
t
t
t的输入。
在继承上一时刻 C t − 1 C_{t-1} Ct−1时,LSTM通过遗忘门 f t f_t ft去控制到底继承多少历史信息,即 C t − p a r t 1 = f t ⨀ C t − 1 C_{t-part1}=f_t\bigodot C_{t-1} Ct−part1=ft⨀Ct−1。
将当前时刻 t t t的输入信息融入进 C t C_t Ct时,LSTM通过输入门 i t i_t it来控制到底有多少输入会被加进 C t C_t Ct中。没有输入门的话,所有的输入都会被加进 C t C_t Ct中,即 t a n h ( W c [ x t , h t − 1 ] + b t ) tanh(W_c[x_t,h_{t-1}]+b_t) tanh(Wc[xt,ht−1]+bt)。但是,输入门的作用就是有选择性的而不是一股脑将输入都加进 C t C_t Ct中。所以, C t − p a r t 2 = i t ⨀ t a n h ( W c [ x t , h t − 1 ] + b c ) C_{t-part2}=i_t\bigodot tanh(W_c[x_t,h_{t-1}]+b_c) Ct−part2=it⨀tanh(Wc[xt,ht−1]+bc)。
最终,当前时刻 t t t的 C t = C t − p a r t 1 + c t − p a r t 2 = f t ⨀ C t − 1 + i t ⨀ t a n h ( W c [ x t , h t − 1 ] + b c ) C_t=C_{t-part1}+c_{t-part2}=f_t\bigodot C_{t-1} + i_t\bigodot tanh(W_c[x_t,h_{t-1}]+b_c) Ct=Ct−part1+ct−part2=ft⨀Ct−1+it⨀tanh(Wc[xt,ht−1]+bc)。公式看起来很复杂是不是,但是只要将上面组成 C t C_t Ct的两部分分开来分析,这个公式就没那么复杂了。
截止到这里,
C
t
C_t
Ct已经完成了当前时刻
t
t
t的信息更新,那么当前时刻的输出是什么呢?LSTM中的输出是根据
C
t
C_t
Ct,并利用输出门去控制哪些信息作为输出的,即
h
t
=
o
t
⨀
t
a
n
h
(
C
t
)
h_t=o_t\bigodot tanh(C_t)
ht=ot⨀tanh(Ct)。也就是说,LSTM每一时刻输出的基础是
C
t
C_t
Ct,只不过套了一层激活函数,并且利用输出门控制了一下信息流出。
值得注意的是,
⨀
\bigodot
⨀表示element-wise product。所以所有门的维度和cell state变量是一致的。
至此,LSTM的基本概念已经介绍完了,下面再来看这张经典的网络结构图,会不会清晰很多。