一、RNN基本结构、梯度消失和梯度爆炸的原因
线性计算单元组成的RNN结构是最简单的一种,我们以此为例来说明造成梯度消失和梯度爆炸的原因:
上图为线性计算单元组成的RNN.依据上图,我们现假设存在一个RNN模型仅包含一个隐藏层,整个RNN模型关注的时间步数为3,
H
0
H_{0}
H0是隐藏层的初始状态,则可以用如下算式表示前向传播过程:
H
1
=
W
X
X
1
+
W
H
H
0
+
b
H
,
O
1
=
W
O
H
1
+
b
O
H_{1} = W_{X}X_{1} + W_{H}H_{0} + b_{H}, O{1} = W_{O}H_{1} + b_{O}
H1=WXX1+WHH0+bH,O1=WOH1+bO
H
2
=
W
X
X
2
+
W
H
H
1
+
b
H
,
O
2
=
W
O
H
2
+
b
O
H_{2} = W_{X}X_{2} + W_{H}H_{1} + b_{H}, O{2} = W_{O}H_{2} + b_{O}
H2=WXX2+WHH1+bH,O2=WOH2+bO
H
3
=
W
X
X
3
+
W
H
H
2
+
b
H
,
O
3
=
W
O
H
3
+
b
O
H_{3} = W_{X}X_{3} + W_{H}H_{2} + b_{H}, O{3} = W_{O}H_{3} + b_{O}
H3=WXX3+WHH2+bH,O3=WOH3+bO
其中
X
=
[
X
0
,
X
1
,
X
2
,
.
.
X
t
]
X =[X_{0},X_{1},X_{2},..X_{t}]
X=[X0,X1,X2,..Xt],
Y
=
[
Y
0
,
Y
1
,
Y
2
,
.
.
Y
t
]
Y =[Y_{0},Y_{1},Y_{2},..Y_{t}]
Y=[Y0,Y1,Y2,..Yt]代表
t
t
t时间步长的特征数据和真值标签。
H
=
[
H
0
,
H
1
,
H
2
,
.
.
H
t
]
H =[H_{0},H_{1},H_{2},..H_{t}]
H=[H0,H1,H2,..Ht] 代表RNN中的隐节点状态,
W
X
,
W
H
,
W
O
W_{X},W_{H},W_{O}
WX,WH,WO是各层的权重,
b
H
,
b
O
b_{H},b_{O}
bH,bO是偏置。
因为是最简单的线性计算单元,我们假设最后使用MSE作为损失函数:
L
t
=
1
2
(
Y
t
−
O
t
)
2
L_{t} = \frac{1}{2}\left(Y{t}-O_{t}\right)^{2}
Lt=21(Yt−Ot)2
对于整个RNN的训练,我们通常需要统计每个时刻的损失之和或平均值,这里我们以每个时间步数累积损失作为整个模型的训练损失:
L
=
∑
t
=
0
T
L
t
L = \sum_{t=0}^{T} L_{t}
L=t=0∑TLt
假设目前只考虑初始的三个时间步长,我们选取最长一条反向传播通路为例,即以
L
3
L_{3}
L3对RNN中的参数求偏导:
∂
L
3
∂
W
O
=
∂
L
3
∂
O
3
∂
O
3
∂
W
O
\frac{\partial L_{3}}{\partial W_{O}} = \frac{\partial L_{3}}{\partial O_{3}} \frac{\partial O_{3}}{\partial W_{O}}
∂WO∂L3=∂O3∂L3∂WO∂O3
∂ L 3 ∂ W X = ∂ L 3 ∂ O 3 ∂ O 3 ∂ H 3 ∂ H 3 ∂ W X + ∂ L 3 ∂ O 3 ∂ O 3 ∂ H 3 ∂ H 3 ∂ H 2 ∂ H 2 ∂ W X + ∂ L 3 ∂ O 3 ∂ O 3 ∂ H 3 ∂ H 3 ∂ H 2 ∂ H 2 ∂ H 1 ∂ H 1 ∂ W X \frac{\partial L_{3}}{\partial W_{X}} = \frac{\partial L_{3}}{\partial O_{3}} \frac{\partial O_{3}}{\partial H_{3}} \frac{\partial H_{3}}{\partial W_{X}} + \frac{\partial L_{3}}{\partial O_{3}} \frac{\partial O_{3}}{\partial H_{3}} \frac{\partial H_{3}}{\partial H_{2}} \frac{\partial H_{2}}{\partial W_{X}} + \frac{\partial L_{3}}{\partial O_{3}} \frac{\partial O_{3}}{\partial H_{3}} \frac{\partial H_{3}}{\partial H_{2}} \frac{\partial H_{2}}{\partial H_{1}} \frac{\partial H_{1}}{\partial W_{X}} ∂WX∂L3=∂O3∂L3∂H3∂O3∂WX∂H3+∂O3∂L3∂H3∂O3∂H2∂H3∂WX∂H2+∂O3∂L3∂H3∂O3∂H2∂H3∂H1∂H2∂WX∂H1
∂ L 3 ∂ W H = ∂ L 3 ∂ O 3 ∂ O 3 ∂ H 3 ∂ H 3 ∂ W H + ∂ L 3 ∂ O 3 ∂ O 3 ∂ H 3 ∂ H 3 ∂ H 2 ∂ H 2 ∂ W H + ∂ L 3 ∂ O 3 ∂ O 3 ∂ H 3 ∂ H 3 ∂ H 2 ∂ H 2 ∂ H 1 ∂ H 1 ∂ W H \frac{\partial L_{3}}{\partial W_{H}} = \frac{\partial L_{3}}{\partial O_{3}} \frac{\partial O_{3}}{\partial H_{3}} \frac{\partial H_{3}}{\partial W_{H}} + \frac{\partial L_{3}}{\partial O_{3}} \frac{\partial O_{3}}{\partial H_{3}} \frac{\partial H_{3}}{\partial H_{2}} \frac{\partial H_{2}}{\partial W_{H}} + \frac{\partial L_{3}}{\partial O_{3}} \frac{\partial O_{3}}{\partial H_{3}} \frac{\partial H_{3}}{\partial H_{2}} \frac{\partial H_{2}}{\partial H_{1}} \frac{\partial H_{1}}{\partial W_{H}} ∂WH∂L3=∂O3∂L3∂H3∂O3∂WH∂H3+∂O3∂L3∂H3∂O3∂H2∂H3∂WH∂H2+∂O3∂L3∂H3∂O3∂H2∂H3∂H1∂H2∂WH∂H1
从中可以发现根据链式法则求导,
W
X
,
W
H
W_{X},W_{H}
WX,WH 会重复出现,这是由于
W
X
,
W
H
W_{X},W_{H}
WX,WH在每个时间步中都参与隐节点
H
H
H的状态估计,随着时间步数的增加,反向传播的路径将会相应的延长。相应的,我们可以总结出任意时刻
t
t
t损失函数对
W
X
W_{X}
WX的偏导:
∂
L
t
∂
W
X
=
∑
k
=
0
t
∂
L
t
∂
O
t
∂
O
t
∂
H
t
(
∏
j
=
k
+
1
t
∂
H
j
∂
H
j
−
1
)
∂
H
k
∂
W
X
\frac{\partial L_{t}}{\partial W_{X}} = \sum_{k=0}^{t} \frac{\partial L_{t}}{\partial O_{t}} \frac{\partial O_{t}}{\partial H_{t}} \left( \prod_{j=k+1}^{t} \frac{\partial H_{j}}{\partial H_{j-1}} \right) \frac{\partial H_{k}}{\partial W_{X}}
∂WX∂Lt=k=0∑t∂Ot∂Lt∂Ht∂Ot⎝⎛j=k+1∏t∂Hj−1∂Hj⎠⎞∂WX∂Hk
任意时刻
t
t
t损失函数对
W
H
W_{H}
WH求偏导同上,只需替换
W
X
W_{X}
WX。
如果隐藏层的激活函数为
t
a
n
h
tanh
tanh,则有
H
j
=
t
a
n
h
(
W
X
X
j
+
W
H
H
j
−
1
+
b
H
)
H_{j} = tanh(W_{X}X_{j}+W_{H}H_{j-1}+b_{H})
Hj=tanh(WXXj+WHHj−1+bH)
∏
j
=
k
+
1
t
∂
H
j
∂
H
j
−
1
=
∏
j
=
k
+
1
t
t
a
n
h
′
⋅
W
H
\prod_{j=k+1}^{t} \frac{\partial H_{j}}{\partial H_{j-1}} = \prod_{j=k+1}^{t} tanh^{'}·W_{H}
j=k+1∏t∂Hj−1∂Hj=j=k+1∏ttanh′⋅WH
激活函数
t
a
n
h
tanh
tanh的导数有如下特性:
t
a
n
h
′
x
=
1
−
t
a
n
h
2
x
tanh^{'} x= 1-tanh^{2}x
tanh′x=1−tanh2x
因此可知
0
<
t
a
n
h
′
≤
1
0< tanh^{'} \le 1
0<tanh′≤1。
在训练过程中,
H
j
H_{j}
Hj的状态极少情况下为0,
t
a
n
h
′
tanh^{'}
tanh′通常是小于1的,如果
W
H
W_{H}
WH值在
(
0
,
1
)
(0,1)
(0,1)范围内,则
∏
j
=
k
+
1
t
t
a
n
h
′
⋅
W
H
\prod_{j=k+1}^{t} tanh^{'}·W_{H}
∏j=k+1ttanh′⋅WH 趋近于0,导致梯度消失。
如果
W
H
W_{H}
WH值很大,则会导致
∏
j
=
k
+
1
t
t
a
n
h
′
⋅
W
H
\prod_{j=k+1}^{t} tanh^{'}·W_{H}
∏j=k+1ttanh′⋅WH 连乘后结果趋近于无穷,导致梯度爆炸。
如何避免这种现象?在RNN的课程学习中,提到裁剪梯度的方法,假设我们把所有模型参数的梯度拼接成一个向量
g
\boldsymbol{g}
g ,并设裁剪的阈值是
θ
\theta
θ 。裁剪后的梯度即:
min
(
θ
∥
g
∥
,
1
)
g
\min\left(\frac{\theta}{\|\boldsymbol{g}\|}, 1\right)\boldsymbol{g}
min(∥g∥θ,1)g
梯度的的
L
2
L_{2}
L2 范数不超过
θ
\theta
θ。函数如下:
def grad_clipping(params, theta, device):
norm = torch.tensor([0.0], device=device)
for param in params:
norm += (param.grad.data ** 2).sum()
norm = norm.sqrt().item()
if norm > theta:
for param in params:
param.grad.data *= (theta / norm)
裁剪梯度是一个简单好用的防止梯度爆炸的方法,实际上造成梯度衰减或梯度爆炸的根本原因是 ∏ j = k + 1 t ∂ H j ∂ H j − 1 \prod_{j=k+1}^{t} \frac{\partial H_{j}}{\partial H_{j-1}} ∏j=k+1t∂Hj−1∂Hj这一连乘项,理想的消除方法就是使 ∂ H j ∂ H j − 1 ∈ [ 0 , 1 ] \frac{\partial H_{j}}{\partial H_{j-1}} \in [0,1] ∂Hj−1∂Hj∈[0,1]其实这就是LSTM做的事情,至于细节如何则会在后续的篇幅中加以介绍。
二、LSTM
长短期记忆(Long short-term memory, LSTM)是一种特殊的RNN结构单元,主要是为了解决长序列训练过程中的梯度消失和梯度爆炸问题。简单来说,就是相比普通的RNN,LSTM能够在更长的序列中有更好的表现。
I
t
=
σ
(
X
t
W
x
i
+
H
t
−
1
W
h
i
+
b
i
)
I_t = σ(X_tW_{xi} + H_{t−1}W_{hi} + b_i)
It=σ(XtWxi+Ht−1Whi+bi)
F
t
=
σ
(
X
t
W
x
f
+
H
t
−
1
W
h
f
+
b
f
)
F_t = σ(X_tW_{xf} + H_{t−1}W_{hf} + b_f)
Ft=σ(XtWxf+Ht−1Whf+bf)
O
t
=
σ
(
X
t
W
x
o
+
H
t
−
1
W
h
o
+
b
o
)
O_t = σ(X_tW_{xo} + H_{t−1}W_{ho} + b_o)
Ot=σ(XtWxo+Ht−1Who+bo)
C
~
t
=
t
a
n
h
(
X
t
W
x
c
+
H
t
−
1
W
h
c
+
b
c
)
\widetilde{C}_t = tanh(X_tW_{xc} + H_{t−1}W_{hc} + b_c)
C
t=tanh(XtWxc+Ht−1Whc+bc)
C
t
=
F
t
⊙
C
t
−
1
+
I
t
⊙
C
~
t
C_t = F_t ⊙C_{t−1} + I_t ⊙\widetilde{C}_t
Ct=Ft⊙Ct−1+It⊙C
t
H
t
=
O
t
⊙
t
a
n
h
(
C
t
)
H_t = O_t⊙tanh(C_t)
Ht=Ot⊙tanh(Ct)
Y
t
=
ϕ
(
H
t
W
h
y
+
b
y
)
Y_{t} = \phi(H_{t}W_{hy}+b_{y})
Yt=ϕ(HtWhy+by)
相比RNN只有一个传递状态
H
t
H_{t}
Ht,LSTM有两个传输状态,一个
C
t
C_{t}
Ct (cell state),和一个
H
t
H_{t}
Ht (hidden state)。其中对于传递下去的
C
t
C_{t}
Ct 改变得很慢,通常输出的
C
t
C_{t}
Ct 是上一个状态传过来的
C
t
−
1
C_{t-1}
Ct−1 加上一些数值。而
H
t
H_{t}
Ht 则在不同节点下往往会有很大的区别。
⊙
\odot
⊙ 是Hadamard Product,也就是操作矩阵中对应的元素相乘,因此要求两个相乘矩阵是同型的。
⊕
\oplus
⊕ 则代表进行矩阵加法。
LSTM主要包括以下几个结构:
- 遗忘门:控制上一时间步的记忆细胞
- 输入门:控制当前时间步的输入
- 输出门:控制从记忆细胞到隐藏状态
- 记忆细胞:⼀种特殊的隐藏状态的信息的流动
三个门输出都经过诸如 s i g m o i d sigmoid sigmoid的激活函数,映射到 ( 0 , 1 ) (0,1) (0,1)的范围,形成门控状态。而记忆细胞则是通过 t a n h tanh tanh转换成 ( − 1 , 1 ) (-1,1) (−1,1)的范围,这里作为记忆细胞短期依赖输出而非门控信号。
LSTM 内部主要有三个阶段:
-
忘记阶段。这个阶段主要是对上一个节点传进来的输入进行选择性忘记。简单来说就是会 “忘记不重要的,记住重要的”。具体来说是通过计算得到的 F t F_{t} Ft (f表示forget)来作为忘记门控,来控制上一个状态的 C t − 1 C_{t-1} Ct−1 哪些需要留哪些需要忘。
-
选择记忆阶段。这个阶段将这个阶段的输入有选择性地进行“记忆”。主要是会对输入 X t X_{t} Xt进行选择记忆。哪些重要则着重记录下来,哪些不重要,则少记一些。当前的输入内容由前面计算得到的 C ~ t \widetilde{C}_t C t表示。而选择的门控信号则是由 I t I_{t} It (i代表information)来进行控制。
将上面两步得到的结果相加,即可得到传输给下一个状态的 C t C_{t} Ct 。也就是上图中的第一个公式。这里也使用 t a n h tanh tanh起到对输入的信息进行压缩的作用。
- 输出阶段。这个阶段将决定哪些将会被当成当前状态的输出。主要是通过 O t O_{t} Ot来进行控制的。并且还对上一阶段得到的 C t C_{t} Ct进行了放缩(通过一个tanh激活函数进行变化)。
与普通RNN类似,输出 Y t Y_{t} Yt往往最终也是通过 H t H_{t} Ht变化得到。
三、GRU
GRU(Gate Recurrent Unit)是循环神经网络(Recurrent Neural Network, RNN)的一种。和LSTM(Long-Short Term Memory)一样,也是为了解决长期记忆和反向传播中的梯度等问题而提出来的。
GRU和LSTM在很多情况下实际表现上相差无几,那么为什么我们要使用新人GRU(2014年提出)而不是相对经受了更多考验的LSTM(1997提出)呢?其实通过代码我们就可以发现,GRU的权重参数相比LSTM减少了1/4。相比LSTM,使用GRU能够达到相当的效果,并且相比之下更容易进行训练,能够很大程度上提高训练效率,因此很多时候会更倾向于使用GRU。
R
t
=
σ
(
X
t
W
x
r
+
H
t
−
1
W
h
r
+
b
r
)
R_{t} = σ(X_tW_{xr} + H_{t−1}W_{hr} + b_r)
Rt=σ(XtWxr+Ht−1Whr+br)
Z
t
=
σ
(
X
t
W
x
z
+
H
t
−
1
W
h
z
+
b
z
)
Z_{t} = σ(X_tW_{xz} + H_{t−1}W_{hz} + b_z)
Zt=σ(XtWxz+Ht−1Whz+bz)
H
~
t
−
1
=
R
t
⊙
H
t
−
1
\widetilde{H}_{t-1} = R_t ⊙H_{t−1}
H
t−1=Rt⊙Ht−1
H
~
t
=
t
a
n
h
(
X
t
W
x
h
+
H
~
t
−
1
W
h
h
+
b
h
)
\widetilde{H}_t = tanh(X_tW_{xh} + \widetilde{H}_{t-1}W_{hh} + b_h)
H
t=tanh(XtWxh+H
t−1Whh+bh)
H
t
=
Z
t
⊙
H
t
−
1
+
(
1
−
Z
t
)
⊙
H
~
t
H_t = Z_t⊙H_{t−1} + (1−Z_t)⊙\widetilde{H}_t
Ht=Zt⊙Ht−1+(1−Zt)⊙H
t
GRU很聪明的一点就在于,我们使用了同一个门控
Z
t
Z_t
Zt 就同时可以进行遗忘和选择记忆(LSTM则要使用多个门控)。
R
t
R_t
Rt在GRU中被称作重置⻔有助于捕捉时间序列⾥短期的依赖关系;
Z
t
Z_t
Zt被称作更新⻔有助于捕捉时间序列⾥⻓期的依赖关系。
Z
t
⊙
H
t
−
1
Z_t⊙H_{t−1}
Zt⊙Ht−1:表示对原本隐藏状态的选择性“遗忘”。这里的
Z
t
Z_t
Zt 可以想象成遗忘门(forget gate),忘记
H
t
−
1
H_{t−1}
Ht−1 维度中一些不重要的信息。
(
1
−
Z
t
)
⊙
H
~
t
(1−Z_t)⊙\widetilde{H}_t
(1−Zt)⊙H
t: 表示对包含当前节点信息的
H
~
t
\widetilde{H}_t
H
t 进行选择性”记忆“。与上面类似,这里的
(
1
−
Z
t
)
(1−Z_t)
(1−Zt)同理会忘记
H
~
t
\widetilde{H}_t
H
t维度中的一些不重要的信息。或者,这里我们更应当看做是对
H
~
t
\widetilde{H}_t
H
t 维度中的某些信息进行选择。
H
t
=
Z
t
⊙
H
t
−
1
+
(
1
−
Z
t
)
⊙
H
~
t
H_t = Z_t⊙H_{t−1} + (1−Z_t)⊙\widetilde{H}_t
Ht=Zt⊙Ht−1+(1−Zt)⊙H
t :结合上述,这一步的操作就是忘记传递下来的
H
t
−
1
H_{t−1}
Ht−1 中的某些维度信息,并加入当前节点输入的某些维度信息。
可以看到,这里的遗忘 Z t Z_t Zt和选择 ( 1 − Z t ) (1−Z_t) (1−Zt) 是联动的。也就是说,对于传递进来的维度信息,我们会进行选择性遗忘,则遗忘了多少权重 ( Z t Z_t Zt),我们就会使用包含当前输入的 H ~ t \widetilde{H}_t H t 中所对应的权重进行弥补 ( 1 − Z t ) (1−Z_t) (1−Zt)。以保持一种”恒定“状态。
四、深度循环神经网络
H
t
(
1
)
=
ϕ
(
X
t
W
x
h
(
1
)
+
H
t
−
1
(
1
)
W
h
h
(
1
)
+
b
h
(
1
)
)
\boldsymbol{H}_t^{(1)} = \phi(\boldsymbol{X}_t \boldsymbol{W}_{xh}^{(1)} + \boldsymbol{H}_{t-1}^{(1)} \boldsymbol{W}_{hh}^{(1)} + \boldsymbol{b}_h^{(1)})
Ht(1)=ϕ(XtWxh(1)+Ht−1(1)Whh(1)+bh(1))
H
t
(
ℓ
)
=
ϕ
(
H
t
(
ℓ
−
1
)
W
x
h
(
ℓ
)
+
H
t
−
1
(
ℓ
)
W
h
h
(
ℓ
)
+
b
h
(
ℓ
)
)
\boldsymbol{H}_t^{(\ell)} = \phi(\boldsymbol{H}_t^{(\ell-1)} \boldsymbol{W}_{xh}^{(\ell)} + \boldsymbol{H}_{t-1}^{(\ell)} \boldsymbol{W}_{hh}^{(\ell)} + \boldsymbol{b}_h^{(\ell)})
Ht(ℓ)=ϕ(Ht(ℓ−1)Wxh(ℓ)+Ht−1(ℓ)Whh(ℓ)+bh(ℓ))
O
t
=
H
t
(
L
)
W
h
q
+
b
q
\boldsymbol{O}_t = \boldsymbol{H}_t^{(L)} \boldsymbol{W}_{hq} + \boldsymbol{b}_q
Ot=Ht(L)Whq+bq
在pytorch的实现中以num_layers
参数进行层数的设置和调整。
五、双向循环神经网络
H
→
t
=
ϕ
(
X
t
W
x
h
(
f
)
+
H
→
t
−
1
W
h
h
(
f
)
+
b
h
(
f
)
)
\overrightarrow{\boldsymbol{H}}_t = \phi(\boldsymbol{X}_t \boldsymbol{W}_{xh}^{(f)} + \overrightarrow{\boldsymbol{H}}_{t-1} \boldsymbol{W}_{hh}^{(f)} + \boldsymbol{b}_h^{(f)})
Ht=ϕ(XtWxh(f)+Ht−1Whh(f)+bh(f))
H
←
t
=
ϕ
(
X
t
W
x
h
(
b
)
+
H
←
t
+
1
W
h
h
(
b
)
+
b
h
(
b
)
)
\overleftarrow{\boldsymbol{H}}_t = \phi(\boldsymbol{X}_t \boldsymbol{W}_{xh}^{(b)} + \overleftarrow{\boldsymbol{H}}_{t+1} \boldsymbol{W}_{hh}^{(b)} + \boldsymbol{b}_h^{(b)})
Ht=ϕ(XtWxh(b)+Ht+1Whh(b)+bh(b))
H
t
=
(
H
→
t
,
H
←
t
)
\boldsymbol{H}_t=(\overrightarrow{\boldsymbol{H}}_{t}, \overleftarrow{\boldsymbol{H}}_t)
Ht=(Ht,Ht)
O
t
=
H
t
W
h
q
+
b
q
\boldsymbol{O}_t = \boldsymbol{H}_t \boldsymbol{W}_{hq} + \boldsymbol{b}_q
Ot=HtWhq+bq
在pytorch的实现中以bidirectional
参数进行层数的设置和调整。