LSTM 通过刻意的设计来避免长期依赖问题。"记住"长期的信息在实践中是 LSTM 的默认行为,而非需要付出很大代价才能获得的能力!
之前的所有RNN都具有一种重复神经网络模块的链式的形式。在标准的 RNN 中,这个重复的模块只有一个非常简单的结构,例如一个 tanh/sigmoid激活函数构成的隐藏层。
LSTM 同样是这样的结构,但是重复的模块拥有一个不同的结构。不同于单一神经网络层,这里是有四个门,以一种非常特殊的方式进行交互。
Traditional LSTM结构
这里我们考虑一个完整复杂的LSTM,假设每一个时刻 t t t都有输出 y t y_t yt。为简略,图中暂时只标出状态和激活函数,暂时没有标出可学习的网络参数矩阵。
这里 c t c_{t} ct是cell state, m t m_{t} mt是hidden state, i t i_{t} it是input gate, f t f_{t} ft是forget gate, o t o_{t} ot是output gate. 这里的 c t c_{t} ct是专门进行线性的循环信息传递的内部状态,然后非线性地输出信息给隐藏层的外部状态 h t , m t h_{t}, m_{t} ht,mt。这里设定 c t , h t ∈ R D c_{t}, h_{t} \in \mathbb{R}^{D} ct,ht∈RD, 而 x t ∈ R M x_{t} \in \mathbb{R}^{M} xt∈RM。可以看出,gate都采用了sigmoid函数(黄色部分),而“状态线”(蓝色和红色部分)上的函数都采用了非线性激活函数 tanh \tanh tanh。
内部状态
c
t
c_{t}
ct的计算为
c
t
=
f
t
⊙
c
t
−
1
+
i
t
⊙
g
t
h
t
=
tanh
(
c
t
)
m
t
=
o
t
⊙
h
t
\begin{aligned} c_{t} &= f_{t} \odot c_{t-1}+i_{t} \odot g_{t} \\ h_{t} &= \tanh(c_{t})\\ m_{t} &= o_{t} \odot h_{t} \end{aligned}
cthtmt=ft⊙ct−1+it⊙gt=tanh(ct)=ot⊙ht
其中,
f
t
,
i
t
,
o
t
∈
[
0
,
1
]
D
f_{t}, i_{t}, o_{t} \in[0,1]^{D}
ft,it,ot∈[0,1]D。在t时刻,
c
t
c_{t}
ct向量记录了到当前时刻为止的历史信息。
g
t
∈
R
D
g_{t}\in \mathbb{R}^{D}
gt∈RD是通过激活函数得到的候选状态
g
t
=
tanh
(
W
c
x
x
t
+
W
c
m
m
t
−
1
+
b
c
)
g_{t} = \tanh\left(W_{c x} x_{t}+W_{c m} m_{t-1}+b_{c}\right)
gt=tanh(Wcxxt+Wcmmt−1+bc)
外部输出是
y
t
=
ϕ
(
W
y
m
m
t
+
b
y
)
y_{t}=\phi\left(W_{ym} m_{t}+b_{y}\right)
yt=ϕ(Wymmt+by)
LSTM这三个门建立了门控机制。一个gate是一个维度为
D
D
D的“近似二值向量”,取值范围是[0, 1] (因为sigmoid函数)。 如果门向量其中一个元素为0,就代表“关闭”了与之相乘的状态向量中对应元素,不允许其通过;如果其中一个元素为1,就代表“开放”了状态向量中对应元素,允许其通过。
三个门的计算方式为
i
t
=
σ
(
W
i
x
x
t
+
W
i
m
m
t
−
1
+
W
i
c
c
t
−
1
+
b
i
)
f
t
=
σ
(
W
f
x
x
t
+
W
f
m
m
t
−
1
+
W
f
c
c
t
−
1
+
b
f
)
o
t
=
σ
(
W
o
x
x
t
+
W
o
m
m
t
−
1
+
W
o
c
c
t
+
b
o
)
\begin{aligned} i_{t}&=\sigma\left(W_{i x} x_{t}+W_{i m} m_{t-1}+W_{i c} c_{t-1}+b_{i}\right) \\ f_{t}&=\sigma\left(W_{f x} x_{t}+W_{f m} m_{t-1}+W_{f c} c_{t-1}+b_{f}\right)\\ o_{t}&=\sigma\left(W_{o x} x_{t}+W_{o m} m_{t-1}+W_{o c} c_{t}+b_{o}\right) \end{aligned}
itftot=σ(Wixxt+Wimmt−1+Wicct−1+bi)=σ(Wfxxt+Wfmmt−1+Wfcct−1+bf)=σ(Woxxt+Wommt−1+Wocct+bo)
三个门对应的作用是:
- 在对输入状态 x t x_{t} xt进行了一次与上一时刻隐藏状态 m t − 1 m_{t-1} mt−1的加和后,又进行了一次非线性激活 tanh \tanh tanh,我们有了第一个隐藏层输出 g t g_{t} gt作为候选状态。然后输入门 i t i_{t} it控制当前时刻 t t t的候选状态 g t g_{t} gt有多少信息需要保存。
- f t f_{t} ft决定我们会从细胞状态中丢弃什么信息, 控制上一个时刻 t − 1 t-1 t−1的内部状态 c t − 1 c_{t-1} ct−1需要遗忘多少信息。
- 在输入门和输出门对上一时刻内部状态和当前时刻候选状态进行加权平均后,我们得到了新的内部状态 c t c_{t} ct, 然后再进行一次非线性激活 tanh \tanh tanh,有了第二个隐藏层输出 h t h_{t} ht。然后输出门 o t o_{t} ot控制当前时刻的内部状态 c t , h t c_{t}, h_{t} ct,ht有多少信息需要输出给外部状态 m t m_{t} mt。
显然,当 f t = 0 , i t = 1 f_t=0, i_t=1 ft=0,it=1时,记忆单元将历史信息清空,只把当前候选状态 g t g_t gt写入。但是不要忘了 g t g_{t} gt依赖于 m t − 1 m_{t-1} mt−1,因此记忆单元 c t c_t ct其实依然和上一时刻的历史信息有所关联。当 f t = 1 , i t = 0 f_t=1, i_t=0 ft=1,it=0时,那么记忆单元 c t c_t ct,还有 h t h_t ht就复制了上一时刻的内容。但是 m t m_t mt和 y t y_t yt就不光上一时刻的内容有关,也与 x t x_{t} xt有关,因为 o t o_{t} ot依赖于 x t x_{t} xt。
当然因为门控函数是sigmoid函数,这里的门只是"软门",上一时刻的信息和当前输入会以一定的比例进行加权平均。
【这里我们可以尝试回忆RNN作为对比。1. RNN只有一个隐藏层激活函数,对上一个时刻的隐藏状态和当前输入一起进行激活。而LSTM有两个隐藏层激活函数,首先对上一个时刻的隐藏状态和当前输入一起进行一次激活,然后对激活后的输出做记忆选择,再对选择结果进行一次激活;第二次激活的输出再被输出门进行一次输出选择。最后传递到外部,走两个路径:输出结果和输出隐藏状态。2. 这里的 m t m_{t} mt对应于RNN中的 S t S_{t} St。这里的 g t g_{t} gt对应RNN的 h t h_{t} ht。原来的RNN只有一个参数矩阵 W W W连接不同时刻,对应于这里的 W c m W_{cm} Wcm,但是这里多了其他与 c t − 1 c_{t-1} ct−1相关的参数矩阵连接不同时刻。3. RNN中的矩阵 U U U对应这里的 W c m W_{cm} Wcm, V V V对应这里的 W y m W_{ym} Wym】
如何理解这里的记忆?
我们可以简单地把cell state这条线理解为一部电影中的主线剧情,因为它承载着从电影一开始的信息记忆细节。在电影剧情不断发展得时候,input就相当于不断推动整个剧情的支线剧情,不断地添加到主线剧情中,每条支线剧情的作用就由遗忘门和输入门两个门进行加权。
后向传播
向量梯度
首先考虑对输入输出向量的梯度。依然是看每一个向量的输出箭头有几个。从输出层往下看
∂
J
∂
m
t
=
∂
J
∂
y
t
W
y
m
T
+
∂
J
∂
g
t
+
1
~
W
c
m
T
+
∂
J
∂
i
t
+
1
~
W
i
m
T
+
∂
J
∂
f
t
+
1
~
W
f
m
T
+
∂
J
∂
o
t
+
1
~
W
o
m
T
∂
J
∂
h
t
=
∂
J
∂
m
t
⊙
o
t
∂
J
∂
c
t
=
∂
J
∂
h
t
d
h
t
d
c
t
+
∂
J
∂
c
t
+
1
⊙
f
t
+
1
+
∂
J
∂
i
t
+
1
W
i
c
T
+
∂
J
∂
f
t
+
1
W
f
c
T
+
+
∂
J
∂
o
t
W
o
c
T
\begin{aligned} \frac{\partial J}{\partial m_{t}} &=\frac{\partial J}{\partial y_{t}} W_{y m}^{T}+\frac{\partial J}{\partial \widetilde{g_{t+1}}} W_{c m}^{T}+\frac{\partial J}{\partial \widetilde{i_{t+1}}} W_{i m}^{T}+\frac{\partial J}{\partial \widetilde{f_{t+1}}} W_{f m}^{T}+\frac{\partial J}{\partial \widetilde{o_{t+1}}} W_{o m}^{T} \\ \frac{\partial J}{\partial h_{t}} &=\frac{\partial J}{\partial m_{t}} \odot o_{t} \\ \frac{\partial J}{\partial c_{t}} &=\frac{\partial J}{\partial h_{t}} \frac{d h_{t}}{d c_{t}}+\frac{\partial J}{\partial c_{t+1}} \odot f_{t+1}+\frac{\partial J}{\partial i_{t+1}} W_{i c}^{T}+\frac{\partial J}{\partial f_{t+1}} W_{f c}^{T}++\frac{\partial J}{\partial o_{t}} W_{o c}^{T} \end{aligned}
∂mt∂J∂ht∂J∂ct∂J=∂yt∂JWymT+∂gt+1
∂JWcmT+∂it+1
∂JWimT+∂ft+1
∂JWfmT+∂ot+1
∂JWomT=∂mt∂J⊙ot=∂ht∂Jdctdht+∂ct+1∂J⊙ft+1+∂it+1∂JWicT+∂ft+1∂JWfcT++∂ot∂JWocT
这里
d
h
t
d
c
t
=
1
−
h
t
2
\dfrac{dh_t}{dc_t}=1-h_{t}^{2}
dctdht=1−ht2.
然后往下到对四个门的梯度,注意,这里是求激活函数和门函数后的输出的梯度
∂
J
∂
g
t
=
∂
J
∂
c
t
⊙
i
t
∂
J
∂
i
t
=
∂
J
∂
c
t
⊙
g
t
∂
J
∂
f
t
=
∂
J
∂
c
t
⊙
c
t
−
1
∂
J
∂
O
t
=
∂
J
∂
m
t
⊙
h
t
\begin{aligned} \frac{\partial J}{\partial g_{t}} &=\frac{\partial J}{\partial c_{t}} \odot i_{t} \\ \frac{\partial J}{\partial i_{t}} &=\frac{\partial J}{\partial c_{t}} \odot g_{t} \\ \frac{\partial J}{\partial f_{t}} &=\frac{\partial J}{\partial c_{t}} \odot c_{t-1} \\ \frac{\partial J}{\partial O_{t}} &=\frac{\partial J}{\partial m_{t}} \odot h_{t} \end{aligned}
∂gt∂J∂it∂J∂ft∂J∂Ot∂J=∂ct∂J⊙it=∂ct∂J⊙gt=∂ct∂J⊙ct−1=∂mt∂J⊙ht
然后继续往下求激活函数和门函数后的输入的梯度
∂
J
∂
i
~
t
=
∂
J
∂
i
t
d
i
t
d
i
~
t
=
∂
J
∂
i
t
i
t
(
1
−
i
t
)
∂
J
∂
f
~
t
=
∂
J
∂
f
t
d
f
t
d
f
~
t
=
∂
J
∂
f
t
f
t
(
1
−
f
t
)
∂
J
∂
o
~
t
=
∂
J
∂
o
t
d
o
t
d
o
~
t
=
∂
J
∂
o
t
o
t
(
1
−
o
t
)
∂
J
∂
g
~
t
=
∂
J
∂
g
t
d
g
t
d
g
~
t
=
∂
J
∂
g
t
(
1
−
g
t
2
)
\begin{aligned} \dfrac{\partial J}{\partial \tilde{i}_{t}}&=\dfrac{\partial J}{\partial i_{t}} \dfrac{d i_{t}}{d \tilde{i}_{t}}=\dfrac{\partial J}{\partial i_{t}} i_{t}\left(1-i_{t}\right)\\ \dfrac{\partial J}{\partial \tilde{f}_{t}}&=\dfrac{\partial J}{\partial f_{t}} \dfrac{d f_{t}}{d \tilde{f}_{t}}=\dfrac{\partial J}{\partial f_{t}} f_{t}\left(1-f_{t}\right)\\ \dfrac{\partial J}{\partial \tilde{o}_{t}}&=\dfrac{\partial J}{\partial o_{t}} \dfrac{d o_{t}}{d \tilde{o}_{t}}=\dfrac{\partial J}{\partial o_{t}} o_{t}\left(1-o_{t}\right)\\ \dfrac{\partial J}{\partial \tilde{g}_{t}}&=\dfrac{\partial J}{\partial g_{t}} \dfrac{d g_{t}}{d \tilde{g}_{t}}=\dfrac{\partial J}{\partial g_{t}}\left(1-g_{t}^{2}\right) \end{aligned}
∂i~t∂J∂f~t∂J∂o~t∂J∂g~t∂J=∂it∂Jdi~tdit=∂it∂Jit(1−it)=∂ft∂Jdf~tdft=∂ft∂Jft(1−ft)=∂ot∂Jdo~tdot=∂ot∂Jot(1−ot)=∂gt∂Jdg~tdgt=∂gt∂J(1−gt2)
最后求对输入
x
t
x_t
xt的梯度
∂
J
∂
x
t
=
∂
J
∂
g
t
~
W
c
x
T
+
∂
J
∂
i
t
~
W
i
x
T
+
∂
J
∂
f
t
~
W
f
x
T
+
∂
J
∂
o
t
~
W
o
x
T
\frac{\partial J}{\partial x_{t}}=\frac{\partial J}{\partial \tilde{g_{t}}} W_{c x}^{T}+\frac{\partial J}{\partial \tilde{i_{t}}} W_{i x}^{T}+\frac{\partial J}{\partial \tilde{f_{t}}} W_{f x}^{T}+\frac{\partial J}{\partial \tilde{o_{t}}} W_{o x}^{T}
∂xt∂J=∂gt~∂JWcxT+∂it~∂JWixT+∂ft~∂JWfxT+∂ot~∂JWoxT
参数梯度
还是从输出层出发,首先对
m
t
m_t
mt相关的五个权重矩阵
∂
J
∂
W
y
m
=
m
t
T
∂
J
∂
y
t
∂
J
∂
W
c
m
=
m
t
T
∂
J
∂
g
t
+
1
∂
J
∂
W
i
m
=
m
t
T
∂
J
∂
i
t
+
1
∂
J
∂
W
f
m
=
m
t
T
∂
J
∂
f
t
+
1
∂
J
∂
W
o
m
=
m
t
T
∂
J
∂
o
t
+
1
\begin{aligned} \frac{\partial J}{\partial W_{y m}} &=m_{t}^{T} \frac{\partial J}{\partial y_{t}} \\ \frac{\partial J}{\partial W_{c m}} &=m_{t}^{T} \frac{\partial J}{\partial g_{t+1}} \\ \frac{\partial J}{\partial W_{i m}} &=m_{t}^{T} \frac{\partial J}{\partial i_{t+1}} \\ \frac{\partial J}{\partial W_{f m}} &=m_{t}^{T} \frac{\partial J}{\partial f_{t+1}} \\ \frac{\partial J}{\partial W_{o m}} &=m_{t}^{T} \frac{\partial J}{\partial o_{t+1}} \end{aligned}
∂Wym∂J∂Wcm∂J∂Wim∂J∂Wfm∂J∂Wom∂J=mtT∂yt∂J=mtT∂gt+1∂J=mtT∂it+1∂J=mtT∂ft+1∂J=mtT∂ot+1∂J
与RNN情况类似, 然后每一个都要加和。注意时刻之间的参数连接,加和索引是
t
−
1
t-1
t−1。
∂
J
∂
W
y
m
=
∑
i
=
1
t
m
i
T
∂
J
∂
y
i
∂
J
∂
W
c
m
=
∑
i
=
1
t
−
1
m
i
T
∂
J
∂
g
i
+
1
∂
J
∂
W
i
m
=
∑
k
=
1
t
−
1
m
k
T
∂
J
∂
i
k
+
1
∂
J
∂
W
f
m
=
∑
i
=
1
t
−
1
m
i
T
∂
J
∂
f
i
+
1
∂
J
∂
W
o
m
=
∑
i
=
1
t
−
1
m
i
T
∂
J
∂
o
i
+
1
\begin{aligned} \frac{\partial J}{\partial W_{y m}} &=\sum_{i=1}^{t}m_{i}^{T} \frac{\partial J}{\partial y_{i}} \\ \frac{\partial J}{\partial W_{c m}} &=\sum_{i=1}^{t-1}m_{i}^{T} \frac{\partial J}{\partial g_{i+1}} \\ \frac{\partial J}{\partial W_{i m}} &=\sum_{k=1}^{t-1}m_{k}^{T} \frac{\partial J}{\partial i_{k+1}} \\ \frac{\partial J}{\partial W_{f m}} &=\sum_{i=1}^{t-1}m_{i}^{T} \frac{\partial J}{\partial f_{i+1}} \\ \frac{\partial J}{\partial W_{o m}} &=\sum_{i=1}^{t-1}m_{i}^{T} \frac{\partial J}{\partial o_{i+1}} \end{aligned}
∂Wym∂J∂Wcm∂J∂Wim∂J∂Wfm∂J∂Wom∂J=i=1∑tmiT∂yi∂J=i=1∑t−1miT∂gi+1∂J=k=1∑t−1mkT∂ik+1∂J=i=1∑t−1miT∂fi+1∂J=i=1∑t−1miT∂oi+1∂J
然后往下对
c
t
c_t
ct相关的三个权重矩阵求梯度
∂
J
∂
W
i
c
=
c
t
T
∂
J
∂
i
t
+
1
∂
J
∂
W
f
c
=
c
t
T
∂
J
∂
f
t
+
1
∂
J
∂
W
o
c
=
c
t
T
∂
J
∂
o
t
\begin{aligned} \frac{\partial J}{\partial W_{i c}} &=c_{t}^{T} \frac{\partial J}{\partial i_{t+1}} \\ \frac{\partial J}{\partial W_{f c}} &=c_{t}^{T} \frac{\partial J}{\partial f_{t+1}} \\ \frac{\partial J}{\partial W_{o c}} &=c_{t}^{T} \frac{\partial J}{\partial o_{t}} \end{aligned}
∂Wic∂J∂Wfc∂J∂Woc∂J=ctT∂it+1∂J=ctT∂ft+1∂J=ctT∂ot∂J
同样做求和。
最后针对
x
t
x_t
xt相关的四个权重矩阵求梯度
∂
J
∂
W
c
x
=
x
t
T
∂
J
∂
g
t
∂
J
∂
W
i
x
=
x
t
T
∂
J
∂
i
t
∂
J
∂
W
f
x
=
x
t
T
∂
J
∂
f
t
∂
J
∂
W
o
x
=
x
t
T
∂
J
∂
o
t
\begin{aligned} \frac{\partial J}{\partial W_{c x}} &=x_{t}^{T} \frac{\partial J}{\partial g_{t}} \\ \frac{\partial J}{\partial W_{i x}} &=x_{t}^{T} \frac{\partial J}{\partial i_{t}} \\ \frac{\partial J}{\partial W_{f x}} &=x_{t}^{T} \frac{\partial J}{\partial f_{t}} \\ \frac{\partial J}{\partial W_{o x}} &=x_{t}^{T} \frac{\partial J}{\partial o_{t}} \end{aligned}
∂Wcx∂J∂Wix∂J∂Wfx∂J∂Wox∂J=xtT∂gt∂J=xtT∂it∂J=xtT∂ft∂J=xtT∂ot∂J
同样加和。
其实本质上还是应用全连接网络的后向传播求梯度思路,只不过这里一个LSTM有两个隐藏层,多了三个门函数。
简化版本的LSTM
简化版本就不考虑cell state对内部三个门在当前时刻和下一时刻的反馈。
三个门的计算方式为
i
t
=
σ
(
W
i
x
x
t
+
W
i
m
m
t
−
1
+
b
i
)
f
t
=
σ
(
W
f
x
x
t
+
W
f
m
m
t
−
1
+
b
f
)
o
t
=
σ
(
W
o
x
x
t
+
W
o
m
m
t
−
1
+
b
o
)
\begin{aligned} i_{t}&=\sigma\left(W_{i x} x_{t}+W_{i m} m_{t-1}+b_{i}\right) \\ f_{t}&=\sigma\left(W_{f x} x_{t}+W_{f m} m_{t-1}+b_{f}\right)\\ o_{t}&=\sigma\left(W_{o x} x_{t}+W_{o m} m_{t-1}+b_{o}\right) \end{aligned}
itftot=σ(Wixxt+Wimmt−1+bi)=σ(Wfxxt+Wfmmt−1+bf)=σ(Woxxt+Wommt−1+bo)
最后所有公式可以简化表达为
[
g
t
o
t
i
t
f
t
]
=
[
tanh
σ
σ
σ
]
(
W
[
x
t
m
t
−
1
]
+
b
)
\left[\begin{array}{c} g_{t} \\ o_{t} \\ i_{t} \\ f_{t} \end{array}\right]=\left[\begin{array}{c} \tanh \\ \sigma \\ \sigma \\ \sigma \end{array}\right]\left(W\left[\begin{array}{c} x_{t} \\ m_{t-1} \end{array}\right]+b\right)
⎣⎢⎢⎡gtotitft⎦⎥⎥⎤=⎣⎢⎢⎡tanhσσσ⎦⎥⎥⎤(W[xtmt−1]+b)
其中
x
t
∈
R
M
x_t \in \mathbb{R}^{M}
xt∈RM,
W
∈
R
4
D
×
(
M
+
D
)
W \in \mathbb{R}^{4D\times (M+D)}
W∈R4D×(M+D),
b
∈
R
4
D
b \in \mathbb{R}^{4D}
b∈R4D,
m
t
−
1
∈
R
D
m_{t-1} \in \mathbb{R}^{D}
mt−1∈RD。