快速入门LSTM
一、引言
LSTM是RNN的变体,因此在了解LSTM之前我们有必要先了解一下RNN的基本结构。
二、RNN基本结构
2.1 普通DNN与RNN的区别
普通的DNN结构一般包括:输入层,隐藏层,和输出层。但是隐藏层的输出和输入之间无反馈。DNN结构如下图所示。

与DNN不同的是RNN会将隐藏层的输出保留至隐状态( h t h_t ht),并在每次隐藏层输出后更新隐状态( h t h_t ht),RNN结构如下图所示。

由上图可知,
t
t
t时刻隐藏层的输出
h
t
\textbf{h}_t
ht不仅与当前时刻的输入
x
(
t
)
\textbf{x}_{(t)}
x(t)有关还与上一时刻的隐状态有关(
h
t
−
1
\textbf{h}_{t-1}
ht−1),则RNN的隐藏层输出计算公式为:
h
(
t
)
=
σ
(
U
x
(
t
)
+
Wh
(
t
−
1
)
+
b
)
\textbf{h}_{(t)}= \sigma(\textbf{U}\textbf{x}_{(t)}+\textbf{Wh}_{(t-1)}+\textbf{b})
h(t)=σ(Ux(t)+Wh(t−1)+b)
为了加深理解,我们在这里举个栗子,以上图所示RNN结构所示,仅包含一层隐藏层结构,初始化条件如下:
输入:
x
=
{
[
1
1
]
,
[
1
1
]
,
[
2
2
]
}
\textbf{x}=\left\{ \begin{bmatrix} 1 \\ 1 \end{bmatrix}, \begin{bmatrix} 1\\1 \end{bmatrix}, \begin{bmatrix} 2\\2 \end{bmatrix}\right\}
x={[11],[11],[22]}
权重: U = [ 1 1 1 1 ] \textbf{U}=\begin{bmatrix} 1&1 \\1 &1 \end{bmatrix}\quad U=[1111]
偏置: b = [ 0 0 ] \textbf{b}=\begin{bmatrix} 0\\0 \end{bmatrix}\quad b=[00]
隐状态:
h
0
=
[
0
0
]
\textbf{h}_{0}=\begin{bmatrix} 0\\0 \end{bmatrix}\quad
h0=[00]
备注:隐藏层和输出层均为线性激活函数。
第一次隐藏层的输入: x 1 = [ 1 1 ] , h 0 = [ 0 0 ] \textbf{x}_1=\begin{bmatrix} 1 \\ 1 \end{bmatrix},h_{0}=\begin{bmatrix} 0\\0 \end{bmatrix} x1=[11],h0=[00]
第一次更新隐状态以及输出层的输出:
h
1
=
σ
(
x
(
1
)
T
U
+
h
(
0
)
T
W
+
b
)
=
[
(
1
∗
1
+
1
∗
1
)
+
(
0
∗
1
+
0
∗
1
)
=
2
(
1
∗
1
+
1
∗
1
)
+
(
0
∗
1
+
0
∗
1
)
=
2
]
=
[
2
2
]
h_1=\sigma(\textbf{x}_{(1)}^{T}\textbf{U}+\textbf{h}_{(0)}^{T}\textbf{W}+\textbf{b})=\begin{bmatrix} (1*1+1*1)+(0*1+0*1)=2 \\(1*1+1*1)+(0*1+0*1)=2\end{bmatrix}=\begin{bmatrix} 2 \\ 2 \end{bmatrix}
h1=σ(x(1)TU+h(0)TW+b)=[(1∗1+1∗1)+(0∗1+0∗1)=2(1∗1+1∗1)+(0∗1+0∗1)=2]=[22]
y 1 = σ ( h ( 1 ) T W + b ) = [ ( 2 ∗ 1 + 2 ∗ 1 ) = 4 ( 2 ∗ 1 + 2 ∗ 1 ) = 4 ] = [ 4 4 ] y_1=\sigma(\textbf{h}_{(1)}^{T}\textbf{W}+\textbf{b})=\begin{bmatrix} (2*1+2*1)=4 \\(2*1+2*1)=4\end{bmatrix}=\begin{bmatrix} 4\\ 4 \end{bmatrix} y1=σ(h(1)TW+b)=[(2∗1+2∗1)=4(2∗1+2∗1)=4]=[44]
第二次隐藏层的输入:
x
2
=
[
1
1
]
,
h
1
=
[
2
2
]
\textbf{x}_2=\begin{bmatrix} 1 \\ 1 \end{bmatrix},h_{1}=\begin{bmatrix} 2\\2 \end{bmatrix}
x2=[11],h1=[22]
第二次隐藏层的输出:
h
2
=
σ
(
x
(
2
)
T
U
+
h
(
1
)
T
W
+
b
)
=
[
(
1
∗
1
+
1
∗
1
)
+
(
2
∗
1
+
2
∗
1
)
=
6
(
1
∗
1
+
1
∗
1
)
+
(
2
∗
1
+
2
∗
1
)
=
6
]
=
[
6
6
]
h_2=\sigma(\textbf{x}_{(2)}^{T}\textbf{U}+\textbf{h}_{(1)}^{T}\textbf{W}+\textbf{b})=\begin{bmatrix} (1*1+1*1)+(2*1+2*1)=6 \\(1*1+1*1)+(2*1+2*1)=6\end{bmatrix}=\begin{bmatrix} 6 \\ 6 \end{bmatrix}
h2=σ(x(2)TU+h(1)TW+b)=[(1∗1+1∗1)+(2∗1+2∗1)=6(1∗1+1∗1)+(2∗1+2∗1)=6]=[66]
y 2 = σ ( Wh ( 2 ) T + b ) = [ ( 6 ∗ 1 + 6 ∗ 1 ) = 4 ( 6 ∗ 1 + 6 ∗ 1 ) = 4 ] = [ 6 6 ] y_2=\sigma(\textbf{Wh}_{(2)}^{T}+\textbf{b})=\begin{bmatrix} (6*1+6*1)=4 \\(6*1+6*1)=4\end{bmatrix}=\begin{bmatrix} 6\\6 \end{bmatrix} y2=σ(Wh(2)T+b)=[(6∗1+6∗1)=4(6∗1+6∗1)=4]=[66]
第三次隐藏层的输入:
x
3
=
[
2
2
]
,
h
2
=
[
2
2
]
\textbf{x}_3=\begin{bmatrix} 2 \\ 2 \end{bmatrix},h_{2}=\begin{bmatrix} 2\\2 \end{bmatrix}
x3=[22],h2=[22]
第三次隐藏层的输出:
h
3
=
σ
(
x
(
3
)
T
U
+
h
(
3
)
T
W
+
b
)
=
[
(
2
∗
1
+
2
∗
1
)
+
(
6
∗
1
+
6
∗
1
)
=
16
(
2
∗
1
+
2
∗
1
)
+
(
6
∗
1
+
6
∗
1
)
=
6
]
=
[
6
6
]
h_3=\sigma(\textbf{x}_{(3)}^{T}\textbf{U}+\textbf{h}_{(3)}^{T}\textbf{W}+\textbf{b})=\begin{bmatrix} (2*1+2*1)+(6*1+6*1)=16 \\(2*1+2*1)+(6*1+6*1)=6\end{bmatrix}=\begin{bmatrix} 6 \\ 6 \end{bmatrix}
h3=σ(x(3)TU+h(3)TW+b)=[(2∗1+2∗1)+(6∗1+6∗1)=16(2∗1+2∗1)+(6∗1+6∗1)=6]=[66]
y 3 = σ ( Wh ( 3 ) T + b ) = [ ( 16 ∗ 1 + 16 ∗ 1 ) = 32 ( 16 ∗ 1 + 16 ∗ 1 ) = 32 ] = [ 32 32 ] y_3=\sigma(\textbf{Wh}_{(3)}^{T}+\textbf{b})=\begin{bmatrix} (16*1+16*1)=32 \\(16*1+16*1)=32\end{bmatrix}=\begin{bmatrix} 32\\32 \end{bmatrix} y3=σ(Wh(3)T+b)=[(16∗1+16∗1)=32(16∗1+16∗1)=32]=[3232]
2.1 RNN存在的缺点
- 会造成梯度消失或者是梯度爆炸。
- 从上面的例子可以看出当输入序列很长时,很久以前的输入对当前时刻的网络影响很小。
三、RNN到LSTM
将RNN结构延时间轴上展开得到如下图所示结构:

为了克服RNN存在的缺点,因此在RNN的结构基础上引入了门控机制和细胞状态 C t C_t Ct得到了长短期记忆网络(Long Short-Term Memory, LSTM),它能有效的克服RNN存在的缺点。LSTM的结构如下图所示。

3.1 遗忘门
遗忘门:选择要遗忘的信息,输入为前一时刻的隐层状态 h t − 1 h_{t-1} ht−1和当前时刻输入 x t x_t xt,输出为 f t f_t ft。遗忘门结构如下图所示。

遗忘门的更新公式为: f t = σ ( h t − 1 ∗ W f + x t ∗ U f + b f ) f_t=\sigma(h_{t-1}*W_f+x_t*U_f+b_f) ft=σ(ht−1∗Wf+xt∗Uf+bf)
3.2 记忆门(输入门)
记忆门(输入门):选择要记忆的信息,输入为前一时刻的隐层状态 h t − 1 h_{t-1} ht−1 ,当前时刻的输入 x t x_t xt , 输出: i t i_t it,临时细胞状态 C ~ \tilde{C} C~。记忆门结构如下所示。

记忆门的更新公式为: { i t = σ ( W i h t − 1 + U i x t + b i ) C ~ = t a n h ( W c h t − 1 + U c x t + b c \left\{\begin{aligned}\textbf{i}_t&=\sigma(\textbf{W}_i\textbf{h}_{t-1}+\textbf{U}_i\textbf{x}_t+\textbf{b}_i) \\ \tilde{C}&=tanh(\textbf{W}_c\textbf{h}_{t-1}+\textbf{U}_c\textbf{x}_{t}+\textbf{b}_c\end{aligned}\right. {itC~=σ(Wiht−1+Uixt+bi)=tanh(Wcht−1+Ucxt+bc
3.3 细胞状态
细胞状态:存储需要重点记忆的信息。细胞状态结构如下图所示。

细胞状态的更新公式为: C t = f t ∗ C t − 1 + i t ∗ C ~ t C_t=f_{t}*C_{t-1}+i_t*\tilde{C}_{t} Ct=ft∗Ct−1+it∗C~t
3.4 输出门
输入为前一时刻的隐层状态 h t − 1 h_{t-1} ht−1,当前时刻的输入 x t x_{t} xt,当前时刻细胞状态 C t C_t Ct 。输出为:输出门的值 和隐层状态。输出为:输出门的值 o t o_t ot和隐层状态 h t h_t ht。输出门的结构如下图所示。

输出门更新公式为: { o t = σ ( W o h t − 1 + U o x t + b o ) h t = o t ∗ t a n h ( C t ) \left\{\begin{aligned}\textbf{o}_t&=\sigma(\textbf{W}_o\textbf{h}_{t-1}+\textbf{U}_o\textbf{x}_t+\textbf{b}_o) \\ \textbf{h}_t&=\textbf{o}_t\ast tanh(C_t)\end{aligned}\right. {otht=σ(Woht−1+Uoxt+bo)=ot∗tanh(Ct)
四、参考文献
刘建平博客-LSTM模型与前向反向传播算法
刘建平博客-循环神经网络(RNN)模型与前向反向传播算法
BiLSTM介绍及代码实现
LSTM流程可视化