本文主要参考李宏毅老师的视频介绍RNN相关知识,主要包括两个部分:
- 分别介绍Navie RNN,LSTM,GRU的结构
- 对比这三者的优缺点
1.RNN,LSTM,GRU结构及计算方式
1.1 Navie RNN
结构图:
计算公式:
h
t
=
σ
(
W
h
h
t
−
1
+
W
x
x
t
)
y
t
=
σ
(
W
y
h
t
)
h^t=\sigma(W^hh^{t-1}+W^xx^t)\\y^t=\sigma(W^yh^t)
ht=σ(Whht−1+Wxxt)yt=σ(Wyht)
依赖每一个时刻的隐状态产生当前的输出,具体计算方式根据自己任务来定。
1.2 LSTM
结构图:
计算公式:
Z
i
=
σ
(
W
i
[
h
t
−
1
,
x
t
]
)
Z
f
=
σ
(
W
f
[
h
t
−
1
,
x
t
]
)
Z
o
=
σ
(
W
o
[
h
t
−
1
,
x
t
]
)
Z
=
t
a
n
h
(
W
[
h
t
−
1
,
x
t
]
)
C
t
=
Z
f
⊙
C
t
−
1
+
Z
i
⊙
Z
h
t
=
Z
o
⊙
t
a
n
h
C
t
y
t
=
σ
(
W
y
h
t
)
Z^i=\sigma(W_i[h^{t-1},x_t])\\Z^f=\sigma(W_f[h^{t-1},x_t])\\Z^o=\sigma(W_o[h^{t-1},x_t])\\Z=\mathop{tanh}(W[h^{t-1},x_t])\\C_t=Z^f\odot C^{t-1}+Z^i\odot{Z}\\h^t=Z^o\odot\mathop{tanh}C^t\\y^t=\sigma(W^yh^t)
Zi=σ(Wi[ht−1,xt])Zf=σ(Wf[ht−1,xt])Zo=σ(Wo[ht−1,xt])Z=tanh(W[ht−1,xt])Ct=Zf⊙Ct−1+Zi⊙Zht=Zo⊙tanhCtyt=σ(Wyht)
1.3 GRU
结构图:
计算公式:
r
=
σ
(
W
r
[
h
t
−
1
,
x
t
]
)
z
=
σ
(
W
r
[
h
t
−
1
,
x
t
]
)
h
t
−
1
′
=
r
⊙
h
t
−
1
h
′
=
t
a
n
h
(
W
h
t
−
1
′
)
h
t
=
(
1
−
z
)
⊙
h
t
−
1
+
z
⊙
h
′
r=\sigma(W^r[h^{t-1},x^t])\\z=\sigma(W^r[h^{t-1},x^t])\\h^{t-1'}=r\odot h^{t-1}\\h'=\mathop{tanh}(Wh^{t-1'})\\h^t=(1-z)\odot h^{t-1}+z\odot h'
r=σ(Wr[ht−1,xt])z=σ(Wr[ht−1,xt])ht−1′=r⊙ht−1h′=tanh(Wht−1′)ht=(1−z)⊙ht−1+z⊙h′
2.RNN,LSTM,GRU的优缺点
2.1 为什么LSTM能解决RNN不能长期依赖的问题
(1)RNN的梯度消失问题导致不能“长期依赖”
RNN中的梯度消失不是指损失对参数的总梯度消失了,而是RNN中对较远时间步的梯度消失了。RNN中反向传播使用的是back propagation through time(BPTT)方法,损失loss对参数W的梯度等于loss在各时间步对w求导之和。用公式表示就是:
∂
E
∂
W
h
=
∑
i
=
1
t
∂
E
∂
y
t
∂
y
t
∂
h
t
∂
h
t
∂
h
i
∂
h
i
∂
W
h
(1)
\frac{\partial E}{\partial W^h}=\sum_{i=1}^t\frac{\partial E}{\partial y^t}\frac{\partial y_t}{\partial h_t}\frac{\partial h_t}{\partial h^i}\frac{\partial h_i}{\partial W^h}\tag1
∂Wh∂E=i=1∑t∂yt∂E∂ht∂yt∂hi∂ht∂Wh∂hi(1)
上式中
∂
h
t
∂
h
i
\frac{\partial h_t}{\partial h^i}
∂hi∂ht计算较复杂,根据复合函数求导方式连续求导。
∂
h
t
∂
h
i
=
∏
k
=
i
+
1
t
∂
h
k
∂
h
k
−
1
(2)
\frac{\partial h_t}{\partial h^i}=\prod_{k=i+1}^t\frac{\partial h_k}{\partial h^{k-1}}\tag2
∂hi∂ht=k=i+1∏t∂hk−1∂hk(2)
∂
h
k
∂
h
k
−
1
\frac{\partial h_k}{\partial h^{k-1}}
∂hk−1∂hk是当前隐状态对上一隐状态求偏导。
∂
h
k
∂
h
k
−
1
=
σ
′
W
h
\frac{\partial h_k}{\partial h^{k-1}}=\sigma'W^h
∂hk−1∂hk=σ′Wh
假设某一时间步j距离t时间步相差了(t-j)时刻。则
∂
h
t
∂
h
i
=
∏
t
−
j
σ
′
W
h
\frac{\partial h_t}{\partial h^i}=\prod^{t-j}\sigma'W^h
∂hi∂ht=∏t−jσ′Wh
如果t-j很大,也就是j距离t时间步很远,当
s
i
g
m
a
′
W
h
>
1
sigma'W^h>1
sigma′Wh>1时,会产生梯度爆炸问题,
s
i
g
m
a
′
W
h
<
1
sigma'W^h<1
sigma′Wh<1时,会产生梯度消失问题。而当t-j很小时,也就是j时t的短期依赖,则不存在梯度消失/梯度爆炸的问题。一般会使用梯度裁剪解决梯度爆炸问题。所以主要分析梯度消失问题。
loss对时间步j的梯度值反映了时间步j对最终输出 y t y_t yt的影响程度。就是j对最终输出 y t y_t yt的影响程度越大,则loss对时间步j的梯度值也就越大。loss对时间步j的梯度值趋于0,就说明了j对最终输出 y t y_t yt没影响。
综上:距离时间步t较远的j的梯度会消失,j对最终输出 y t y_t yt没影响。也就是说RNN中不能长期依赖。
(2)LSTM如何解决梯度消失
LSTM设计的初衷就是让当前记忆单元对上一记忆单元的偏导为常数。如在1997年最初版本的LSTM,记忆细胞更新公式为:
C
t
=
C
t
−
1
+
Z
i
⊙
x
t
∂
C
t
∂
C
t
−
1
=
1
C^t=C^{t-1}+Z^i\odot x^t\\\frac{\partial C_t}{\partial C^{t-1}}=1
Ct=Ct−1+Zi⊙xt∂Ct−1∂Ct=1
后来为了避免记忆细胞无线增长,引入了“遗忘门”。更新公式为:
C
t
=
Z
f
⊙
C
t
−
1
+
Z
i
⊙
x
t
C^t=Z^f\odot C^{t-1}+Z^i\odot x^t\\
Ct=Zf⊙Ct−1+Zi⊙xt
此时连续偏导的值为:
∂
C
t
∂
C
t
−
1
=
Z
f
\frac{\partial C_t}{\partial C^{t-1}}=Z^f
∂Ct−1∂Ct=Zf
虽然
Z
f
Z^f
Zf是一个[0,1]区间的数值,不在满足当前记忆单元对上一记忆单元的偏导为常数。但通常会给遗忘门设置一个很大的偏置项,使得遗忘门在多数情况下是关闭的,只有在少数情况下开启。回顾下遗忘门的公式,这里我们加上了偏置b。
Z
f
=
σ
(
W
f
[
h
t
−
1
,
x
t
]
+
b
f
)
Z^f=\sigma(W_f[h^{t-1},x_t]+b^f)
Zf=σ(Wf[ht−1,xt]+bf)
趋向于1时,遗忘门关闭,趋向于0,时,遗忘门打开。通过设置大的偏置项,使得大多数遗忘门的值趋于1。也就缓解了由于小数连乘导致的梯度消失问题。
2.2 相较于LSTM,GRU的优势
GRU的参数量少,减少过拟合的风险
LSTM的参数量是Navie RNN的4倍(看公式),参数量过多就会存在过拟合的风险,GRU只使用两个门控开关,达到了和LSTM接近的结果。其参数量是Navie RNN的三倍