自然语言处理之LSTM、GRU
一、前言
循环神经网络RNN,无法处理长距离依赖问题,针对此问题,提出了LSTM和GRU。
二、LSTM: (Long short-term memory)
2.1 LSTM结构
- 上图中左侧为RNN,右侧为LSTM结构图:RNN在隐藏层中只传递一个状态值 h h h,LSTM不仅传递 h h h,还传递一个状态值 c c c,每一个隐藏层中的神经元都接收上一时刻传递的 h t − 1 h_{t-1} ht−1和 c t − 1 c_{t-1} ct−1,经过计算得到 h t h_{t} ht和 c t c_{t} ct再传入下一时刻。
- 上图是LSTM的某一隐藏层的局部结构,其中包含3个门控结构:红色方框的遗忘门、绿色方框的输入门、紫色方框的输出门,3个门中包含3个sigmoid函数和2和tanh函数。
- 使用sigmoid函数的原因是:sigmoid函数能够将输入映射到[0,1]空间中,那么咱们就可以根据映射之后的概率对于上一时刻传递的信息进行有选择的去除,保留和输出。比如sigmoid函数的值为1也就是门的全开状态,则代表所有的信息都被保留,如果sigmoid函数为0也就是门的全闭状态,则代表所有的信息都不被保留。
- 使用tanh函数是:为了对数据进行处理,映射到[-1,1]的空间。
- 说明:公式中 ⋅ · ⋅表示矩阵相乘, ⊗ \otimes ⊗表示点乘。
2.1.1 遗忘门
f
t
=
σ
(
W
f
⋅
[
h
t
−
1
,
x
t
]
+
b
f
)
(1)
f_t=\sigma(W_f·[h_{t-1},x_t]+b_f)\tag{1}
ft=σ(Wf⋅[ht−1,xt]+bf)(1)
- 将 t − 1 t-1 t−1时刻传入的 h t − 1 h_{t-1} ht−1与时刻 t t t的输入 x t x_t xt进行拼接,然后通过权值矩阵 W f W_f Wf转换后,加上偏置 b f b_f bf,最后通过sigmoid函数映射为 [ 0 , 1 ] [0,1] [0,1]范围内,形成遗忘门;
- 然后通过遗忘门
f
t
f_t
ft对上一时刻传入的
c
t
−
1
c_{t-1}
ct−1进行有选择的遗忘,将
c
t
−
1
c_{t-1}
ct−1与
f
t
f_t
ft进行点乘,得到去除一部分信息后的遗忘输出,所以遗忘门的输出值为:
C t − 1 ⊗ f t (2) C_{t-1}\otimes f_t\tag{2} Ct−1⊗ft(2)
2.1.2 输入门
i
t
=
σ
(
W
i
⋅
[
h
t
−
1
,
x
t
]
+
b
i
)
(3)
i_t=\sigma(W_i·[h_{t-1},x_t]+b_i)\tag{3}
it=σ(Wi⋅[ht−1,xt]+bi)(3)
C t ~ = t a n h ( W C ⋅ [ h t − 1 , x t ] + b C ) (4) \tilde{C_t}=tanh(W_C·[h_{t-1},x_t]+b_C)\tag{4} Ct~=tanh(WC⋅[ht−1,xt]+bC)(4)
- 输入门的输出值为:
i t ⊗ C t ~ (5) i_t\otimes\tilde{C_t}\tag{5} it⊗Ct~(5) - 将遗忘门的输出值和输入门的输出值加起来,就可以得到
C
t
C_t
Ct:
C t = C t − 1 ⊗ f t + i t ⊗ C t ~ (6) C_t=C_{t-1}\otimes f_t+i_t\otimes\tilde{C_t}\tag{6} Ct=Ct−1⊗ft+it⊗Ct~(6) - C t C_t Ct中保留了 t − 1 t-1 t−1时刻传入的部分信息和 t t t时刻传入的经过筛选后的信息。
2.1.3 输出门
o
t
=
σ
(
W
o
⋅
[
h
t
−
1
,
x
t
]
+
b
o
)
(7)
o_t=\sigma(W_o·[h_{t-1},x_t]+b_o)\tag{7}
ot=σ(Wo⋅[ht−1,xt]+bo)(7)
h
t
=
o
t
⊗
t
a
n
h
(
C
t
)
(8)
h_t=o_t\otimes tanh(C_t)\tag{8}
ht=ot⊗tanh(Ct)(8)
- 这样就计算出来t时刻的所有输出值,
h
t
h_t
ht和
C
t
C_t
Ct,然后
h
t
h_t
ht和
C
t
C_t
Ct又可以传入到下一时刻来进行循环操作了。
计算 t t t时刻的输出 y t y_t yt:
y t = g ( V ⋅ h t ) (9) y_t=g(V·h_t)\tag{9} yt=g(V⋅ht)(9)
上式中 V V V是隐藏层到输出层之间的权值矩阵, g ( ) g() g()是激活函数,如果是二分类采用 s i g m o i d sigmoid sigmoid,多分类则采用 s o f t m a x softmax softmax。
2.2 LSTM如何缓解RNN梯度消失问题
-
RNN导致梯度消失的原因:因为tanh和sigmoid函数的导数均小于1,一系列小于1的数连乘,连乘的数一多,连乘的结果就有很大概率为0,那么参数便不能进行更新了,从而导致的梯度消失现象发生。
-
虽然RNN也可以通过调整Ws来使得连乘接近于1,但是RNN是通过乘以Ws来调节,乘法数值变化较快,比较敏感,参数很难调,一不小心就超过了上界发生梯度爆炸,达不到下界不发生梯度消失。而LSTM是通过加上bf来调节,来降低梯度消失的风险,调节起来更容易,相对于RNN较好。所以之前也只是说了LSTM能相对于RNN缓解梯度消失的问题,并不能完全消除。
-
类比到LSTM中:
-
将 f t , i t , C t ~ f_t, i_t, \tilde{C_t} ft,it,Ct~带入 C t C_t Ct中可得:
C t = σ ( W f ⋅ [ h t − 1 , x t ] + b f ) C t − 1 + σ ( W i ⋅ [ h t − 1 , x t ] + b i ) t a n h ( W C ⋅ [ h t − 1 , x t ] + b C ) (10) C_t=\sigma(W_f·[h_{t-1},x_t]+b_f)C_{t-1}+\sigma(W_i·[h_{t-1},x_t]+b_i)tanh(W_C·[h_{t-1},x_t]+b_C)\tag{10} Ct=σ(Wf⋅[ht−1,xt]+bf)Ct−1+σ(Wi⋅[ht−1,xt]+bi)tanh(WC⋅[ht−1,xt]+bC)(10) -
C t C_t Ct对 C t − 1 C_{t-1} Ct−1求偏导结果为:
∂ C t ∂ C t − 1 = σ ( W f ⋅ [ h t − 1 , x t ] + b f ) (11) \frac{\partial C_t}{\partial C_{t-1}}=\sigma(W_f·[h_{t-1},x_t]+b_f)\tag{11} ∂Ct−1∂Ct=σ(Wf⋅[ht−1,xt]+bf)(11)
上式中取值范围为 [ 0 , 1 ] [0,1] [0,1]之间,在实际参数更新过程中,可以控制 b f b_f bf较大,使得该值更接近于1,这样即使在多次连乘的情况下,梯度也不会消失。
2.3 LSTM优缺点
- 缺点:包含 W f , W i , W o , W c , b f , b i , b o , b c , V W_f,W_i,W_o,W_c,b_f,b_i,b_o,b_c,V Wf,Wi,Wo,Wc,bf,bi,bo,bc,V等9个参数,参数较多,调参对机器性能要求比较高;
- 优点:有更多的参数对于模型的调节更加精确。
三、GRU: (Gate Recurrent Unit)
3.1 GRU结构
3.2 GRU内部结构
3.2.1 reset重置门
r
t
=
σ
(
W
r
⋅
[
h
t
−
1
,
x
t
]
+
b
r
)
(12)
r_t=\sigma(W_r·[h_{t-1},x_t]+b_r)\tag{12}
rt=σ(Wr⋅[ht−1,xt]+br)(12)
得到重置门
r
t
r_t
rt后,将
r
t
r_t
rt与上一时刻传入的
h
t
−
1
h_{t-1}
ht−1进行点乘,得到重置之后的数据:
h
t
−
1
′
=
h
t
−
1
⊗
r
t
(13)
h_{t-1}'=h_{t-1}\otimes r_t\tag{13}
ht−1′=ht−1⊗rt(13)
然后将得到的
h
t
−
1
′
h_{t-1}'
ht−1′与
x
t
x_t
xt进行拼接:
h
′
=
t
a
n
h
(
W
⋅
[
h
t
−
1
′
,
x
t
]
+
b
)
(14)
h'=tanh(W·[h_{t-1}',x_t]+b)\tag{14}
h′=tanh(W⋅[ht−1′,xt]+b)(14)
这里的
h
′
h'
h′包含了输入信息
x
t
x_t
xt,和经过选择后的上一时刻的重要信息
h
t
−
1
′
h_{t-1}'
ht−1′,这样达到了记忆当前状态信息的目的。
3.2.2 update更新门
z
t
=
σ
(
W
z
⋅
[
h
t
−
1
,
x
t
]
+
b
z
)
(15)
z_t=\sigma(W_z·[h_{t-1},x_t]+b_z)\tag{15}
zt=σ(Wz⋅[ht−1,xt]+bz)(15)
z
t
z_t
zt就是更新门,更新门同时进行遗忘和记忆的方式如下:
h
t
=
z
t
⊗
h
t
−
1
+
(
1
−
z
t
)
⊗
h
′
(16)
h^t=z_t\otimes h^{t-1}+(1-z_t)\otimes h'\tag{16}
ht=zt⊗ht−1+(1−zt)⊗h′(16)
- 其中 z t ⊗ h t − 1 z_t\otimes h^{t-1} zt⊗ht−1:表示对原本隐藏状态的选择性遗忘, z t z_t zt看做是遗忘门,遗忘 h t − 1 h_{t-1} ht−1中不重要的信息;
- ( 1 − z t ) ⊗ h ′ (1-z_t)\otimes h' (1−zt)⊗h′:表示对包含当前节点信息的 h ′ h' h′进行选择性记忆;
- h t = z t ⊗ h t − 1 + ( 1 − z t ) ⊗ h ′ h^t=z_t\otimes h^{t-1}+(1-z_t)\otimes h' ht=zt⊗ht−1+(1−zt)⊗h′:遗忘上一时刻中 h t − 1 h^{t-1} ht−1的某些信息,并记忆当前节点输入的某些维度信息。
- 式(16)中遗忘的权重 z z z和记忆的权重 1 − z 1-z 1−z是互补的,遗忘多少信息,就弥补多少信息。
3.3 GRU总结
- GRU只有两个门,相应地参数也就比LSTM要少,效率要高,但是结果并没有多大的区别。