深度学习经典结构之长短期记忆网络LSTM
一、 背景
普通循环神经网络
简单的循环神经网络如下图所示,其传递公式为
h
t
=
f
(
U
h
t
−
1
+
W
x
t
+
b
)
y
t
=
V
h
t
h_t=f(Uh_{t-1}+Wx_t+b)\\y_t=Vh_t
ht=f(Uht−1+Wxt+b)yt=Vht,其中f(·)为激活函数,ht为单元输出.
该结构存在两个问题,
- 梯度爆炸
传递过程中梯度可能会过大,从而导致梯度爆炸问题 - 记忆容量
随着𝒉𝑡 不断累积存储新的输入信息,会发生饱和现象.假设 𝑔(⋅) 为 Logistic 函数,则随着时间 𝑡 的增长,𝒉𝑡 会变得越来越大,从而导致𝒉变得饱和.因此,隐状态𝒉𝑡 可以存储的信息是有限的,
随着记忆单元存储的内容越来越多,其丢失的信息也越来越多.
由此引入门控机制的模型结构
二、LSTM
1、原理
1)、结构
引入了三个门,添加了一个内部状态
c
t
−
1
,
h
t
−
1
c_{t-1},h_{t-1}
ct−1,ht−1则为外部状态,其三个门可以理解为下面的作用
(1) 遗忘门𝒇𝑡 控制上一个时刻的内部状态
c
t
−
1
c_{t-1}
ct−1 需要遗忘多少信息.
(2) 输入门𝒊𝑡 控制当前时刻的候选状态
c
t
~
\tilde{c_t}
ct~有多少信息需要保存.
(3) 输出门 𝒐𝑡 控制当前时刻的内部状态
c
t
c_{t}
ct有多少信息需要输出给外部状态
h
t
h_{t}
ht.
其公式如下
o
u
t
p
u
t
t
=
f
(
d
o
t
(
h
t
,
U
o
)
+
d
o
t
(
x
t
,
W
o
)
+
d
o
t
(
C
t
,
V
o
)
+
b
o
)
output_t = f(dot(h_t, Uo) + dot(x_t, Wo) + dot(C_t, Vo) + bo)
outputt=f(dot(ht,Uo)+dot(xt,Wo)+dot(Ct,Vo)+bo)
门的计算:
i
t
=
f
(
d
o
t
(
h
t
−
1
,
U
i
)
+
d
o
t
(
x
t
,
W
i
)
+
b
i
)
f
t
=
f
(
d
o
t
(
h
t
−
1
,
U
f
)
+
d
o
t
(
x
t
,
W
f
)
+
b
f
)
o
t
=
f
(
d
o
t
(
h
t
−
1
,
U
k
)
+
d
o
t
(
x
t
,
W
k
)
+
b
k
)
i_t = f(dot(h_{t-1}, Ui) + dot(x_t, Wi) + bi) \\ f_t = f(dot(h_{t-1}, Uf) + dot(x_t, Wf) + bf) \\ o_t = f(dot(h_{t-1}, Uk) + dot(x_t, Wk) + bk)
it=f(dot(ht−1,Ui)+dot(xt,Wi)+bi)ft=f(dot(ht−1,Uf)+dot(xt,Wf)+bf)ot=f(dot(ht−1,Uk)+dot(xt,Wk)+bk)
内部状态更新与外部状态更新:
{
c
t
~
=
t
a
n
h
(
h
t
−
1
+
x
t
)
c
t
=
d
o
t
(
f
t
,
c
t
−
1
)
+
d
o
t
(
i
t
,
c
t
~
)
h
t
=
d
o
t
(
o
t
,
t
a
n
h
(
c
t
)
)
\begin{cases} \tilde{c_t}=tanh(h_{t-1}+x_t) \\ c_t=dot(f_t,c_{t-1})+dot(i_t,\tilde{c_t}) \end{cases}\\ h_t=dot(o_t,tanh(c_t))
{ct~=tanh(ht−1+xt)ct=dot(ft,ct−1)+dot(it,ct~)ht=dot(ot,tanh(ct))
,其中dot表示点乘,outout_t为LSTM单元的输出
2)、参数学习
采用梯度下降法,有随时间反向传播(BPTT)算法和实时循环学习(RTRL)算法.
- 时间反向传播(BPTT)算法:从最后一步求导,反向逐渐得到初始步的导数;一般网络输出
维度远低于输入维度,因此 BPTT 算法的计算量会更小,但是 BPTT 算法需要保
存所有时刻的中间梯度,空间复杂度较高 - 实时循环学习(RTRL)算法:基于前向传播算法,从计算初始步长导数开始,逐步计算到最后一步导数.RTRL算法不需要梯度回传,省略了梯度的空间存储,因此非常适合用于需要在线学习或无限序列的任务中
3)、如何实现记忆长短存储,解决梯度爆炸
从三个门的公式可以看到,其激活函数与状态生成的激活函数不一样,门的激活函数选用sigmoid类函数,取值在[0,1]之间,符合门的定义,是实现长短期记忆的功能控制开关;
状态激活函数tanh,取值[-1,1],是一个0值中心化的函数,在0附近的梯度较大,收敛快.
- 记忆单元 𝒄 可以在某个时刻捕捉到某个关键信息,并有能力将此关键信息保存一定的时间间隔.记忆单元 𝒄 中保存信息的生命周期要长于短期记忆 𝒉,但又远远短于长期记忆,因此叫长短期记忆.
- 梯度爆炸来源于链式法则的连乘,我们看一下梯度
∂
C
t
∂
C
t
−
1
\frac {\partial C_t} {\partial C_{t-1}}
∂Ct−1∂Ct
∂ C t ∂ C t − 1 = ∂ d o t ( f t , c t − 1 ) ∂ C t − 1 + ∂ d o t ( i t , c t ~ ) ∂ C t − 1 \frac {\partial C_t} {\partial C_{t-1}}=\frac {\partial dot(f_t,c_{t-1})} {\partial C_{t-1}}+ { \frac {\partial {dot(i_t,\tilde{c_t})}} {\partial C_{t-1}} } ∂Ct−1∂Ct=∂Ct−1∂dot(ft,ct−1)+∂Ct−1∂dot(it,ct~)解得
可以看到梯度部分是个加项,而 三个门都是神经网网络自己学到的,因此可以通过学习改变门控的值,使梯度维持在合理的范围内.经使用经验得知时序超过100步长,该结构仍有梯度消失问题.
4)、正则化
Keras的每个循环层都有两个与 dropout 相关的参数:一个是 dropout,它是一个浮点数,指定该层
输入单元的 dropout 比率;另一个是 recurrent_dropout,指定循环单元的 dropout 比率。因为使用 dropout正则化的网络总是需要更长的时间才能完全收敛,所以网络训练轮次增加为原来的 2 倍。
2、常见的LSTM的变体
1)、LSTM 单元结构变体门控循环单元GRU
GRU不引入额外的状态,而是引入的一个更新门与重置门.其公式如下
r
t
=
f
(
d
o
t
(
h
t
−
1
,
U
r
)
+
d
o
t
(
x
t
,
W
r
)
+
b
r
)
z
t
=
f
(
d
o
t
(
h
t
−
1
,
U
z
)
+
d
o
t
(
x
t
,
W
z
)
+
b
z
)
h
t
~
=
t
a
n
h
(
d
o
t
(
x
t
,
W
h
)
+
d
o
t
(
d
o
t
(
r
t
,
h
t
−
1
)
,
U
h
)
+
b
h
)
h
t
=
f
(
z
t
h
t
−
1
+
(
1
−
z
t
)
h
t
~
)
r_t=f(dot(h_{t-1},U_r)+dot(x_{t},W_r)+b_r)\\ z_t=f(dot(h_{t-1},U_z)+dot(x_{t},W_z)+b_z)\\ \tilde{h_t}=tanh(dot(x_t,W_h)+dot(dot(r_t,h_{t-1}),U_h)+b_h)\\ h_t=f(z_t h_{t-1}+(1-z_t)\tilde{h_t})
rt=f(dot(ht−1,Ur)+dot(xt,Wr)+br)zt=f(dot(ht−1,Uz)+dot(xt,Wz)+bz)ht~=tanh(dot(xt,Wh)+dot(dot(rt,ht−1),Uh)+bh)ht=f(ztht−1+(1−zt)ht~)
在LSTM网络中,输入门和遗忘门是互补关系,具有一定的冗余性.GRU网络直接使用一个门来控制输入和遗忘之间的平衡.
当 𝒛𝑡 = 0 时,当前状态 𝒉𝑡 和前一时刻的状态𝒉𝑡−1 之间为非线性函数关系
当𝒛𝑡 = 1时,𝒉𝑡 和𝒉𝑡−1 之间为线性函数关系.
当
r
t
r_{t}
rt= 0时,候选状态
h
t
~
\tilde{h_t}
ht~ 只和当前输入相关,和历史状态无关.
当
r
t
r_{t}
rt = 1时,候选状态
h
t
~
\tilde{h_t}
ht~和当前输入以及历史状态
h
t
−
1
h_{t-1}
ht−1 相关,和简单循环网络一致.
该变体的优点在于结构简单,计算要快于LSTM.
2)、LSTM网络结构变体
- 堆叠结构
遇到了性能瓶颈,所以我们应该考虑增加网络容量。增加网络容量的通常做法是增加每层单元数或增加层数.逐个堆叠循环层,所有中间层都应该返回完整的输出序列(一个 3D 张量),而不是只返回最后一个时间步的输出。 - 双向结构
在机器学习中,如果一种数据表示不同但有用,那么总是值得加以利用,这种表示与其他表示的差异越大越好,它们提供了查看数据的全新角度,抓住了数据中被其他方法忽略的内容,因此可以提高模型在某个任务上的性能。考虑到逆序可能提供到不一样的表示,因此提出了双向RNN,而该结构在文本数据集上表现突出,这证实了一个假设:虽然单词顺序对理解语言很重要,但使用哪种顺序并不重要。重要的是,在逆序序列上训练的 RNN 学到的表示不同于在原始序列上学到的表示.
三、参数计算
1、RNN
以keras框架为例,其接受的输入为(batch_size, timesteps, input_features),循环层隐藏单元数为level_nums,则该循环层的参数为
(
i
n
p
u
t
f
e
a
t
u
r
e
s
+
h
i
d
d
e
n
n
u
m
s
)
∗
h
i
d
d
e
n
n
u
m
s
+
h
i
d
d
e
n
n
u
m
s
(inputfeatures+hiddennums)*hiddennums+hiddennums
(inputfeatures+hiddennums)∗hiddennums+hiddennums;假设输入数据为(None,None,32),输入到一个循环层SimpleRNN(32),则该循环层的参数=(32+32)*32+32=2080.
接下来举一个具体的输入例子进行理解,假设有输入(1,2,1),t0=1,t1=2;循环层隐藏单元数为2,则在计算的时候是输入与状态进行拼接✖️参数+偏差=输出,状态维度由循环层维度决定等于2,与输入拼接后维度为3,为保障下一状态输出维度仍为2,因此参数维度为(3,2),偏差维度为2.
2、LSTM
以keras框架为例,其接受的输入为(batch_size, timesteps, input_features),循环层隐藏单元数数为level_nums,则该循环层的参数
4
∗
(
(
i
n
p
u
t
f
e
a
t
u
r
e
s
+
h
i
d
d
e
n
n
u
m
s
)
∗
h
i
d
d
e
n
n
u
m
s
+
h
i
d
d
e
n
n
u
m
s
)
4*((inputfeatures+hiddennums)*hiddennums+hiddennums)
4∗((inputfeatures+hiddennums)∗hiddennums+hiddennums);假设输入数据为(None,None,32),输入到一个循环层LSTM(32),则该循环层的参数=4*((32+32)*32+32)=8320.
根据LSTM的单元结构看,有四组权重,分别为遗忘门,输入门,输出门,跟候选记忆生成,每个权重都是一个全链接,其输入均为前一状态与当前输入.因此公式可以得到.
3、GRU
以keras框架为例,其接受的输入为(batch_size, timesteps, input_features),循环层隐藏单元数数为level_nums,则该循环层的参数
3
∗
(
(
i
n
p
u
t
f
e
a
t
u
r
e
s
+
h
i
d
d
e
n
n
u
m
s
)
∗
h
i
d
d
e
n
n
u
m
s
+
h
i
d
d
e
n
n
u
m
s
)
3*((inputfeatures+hiddennums)*hiddennums+hiddennums)
3∗((inputfeatures+hiddennums)∗hiddennums+hiddennums);假设输入数据为(None,None,32),输入到一个循环层LSTM(32),则该循环层的参数=3*((32+32)*32+32)=6240.
根据GRU的单元结构看,有3组权重,分别为更新门,重置门,跟候选状态生成,每个权重都是一个全链接,其输入均为前一状态与当前输入.因此公式可以得到.