RNN Illustration
W,U,V参数在整个传播过程共享(参数共享)
RNN training process
一般神经网络的反向传播
反向传播过程
在针对每个C进行更新的时候,都需要反向传播到第一个时间步去计算对应的偏导数以更新参数,更详细的分析在下面。
梯度消失or爆炸
更详细的分析
上图中在求偏导过程中标注的long time dependency就是RNN出现梯度消失or爆炸的原因,归纳之后也就是
∏
j
=
k
+
1
t
∂
S
j
∂
S
j
−
1
\prod_{j=k+1}^{t} \frac{\partial S_{j}}{\partial S_{j-1}}
j=k+1∏t∂Sj−1∂Sj因为多个时间步相乘会造成结果无穷大or无穷小。
解决梯度消失or爆炸的方法
应用gate机制的网络:LSTM和GRU
要消除梯度消失or爆炸这种情况就需要把
∏
j
=
k
+
1
t
∂
S
j
∂
S
j
−
1
\prod_{j=k+1}^{t} \frac{\partial S_{j}}{\partial S_{j-1}}
∏j=k+1t∂Sj−1∂Sj在求偏导的过程中去掉。
LSTM
LSTM原理
原理讲解查看这里
下图显示了数据在LSTM单元内流动的过程
一些符号和公式
f
t
,
i
t
,
o
t
f_t, i_t, o_{t}
ft,it,ot是三个gate
LSTM解决梯度消失or爆炸
RNN中梯度消失or爆炸的原因在于
∏
j
=
k
+
1
t
∂
S
j
∂
S
j
−
1
\prod_{j=k+1}^{t} \frac{\partial S_{j}}{\partial S_{j-1}}
∏j=k+1t∂Sj−1∂Sj隐状态在反向传播时的依赖性,LSTM通过让这个递归式等于一个常数来解决这个问题。
首先LSTM引入一个单独的cell state
C
t
C_t
Ct,cell state遵从以下公式
C
t
=
f
C
t
−
1
+
i
C
~
t
C_{t}=f C_{t-1}+i \widetilde{C}_{t}
Ct=fCt−1+iC
t
下面直讲重点,全部详细讲解见这里
当我们做bp时,LSTM的recursive derivative为
∂
C
t
∂
C
t
−
1
\frac{\partial C_t}{\partial C_{t-1}}
∂Ct−1∂Ct, 根据上面的公式和符号,可以得知
C
t
C_{t}
Ct是
f
t
,
i
t
,
C
~
t
f_t, i_t, \widetilde{C}_{t}
ft,it,C
t的函数,而
f
t
,
i
t
,
C
~
t
f_t, i_t, \widetilde{C}_{t}
ft,it,C
t是
C
t
−
1
C_{t-1}
Ct−1的函数(因为
h
t
=
o
t
tanh
(
C
t
)
h_{t}=o_{t} \tanh \left(C_{t}\right)
ht=ottanh(Ct))
完整的bp
∂
C
t
∂
C
t
−
1
=
∂
C
t
∂
f
t
∂
f
t
∂
h
t
−
1
∂
h
t
−
1
∂
C
t
−
1
+
∂
C
t
∂
i
t
∂
i
t
∂
h
t
−
1
∂
h
t
−
1
∂
C
t
−
1
+
∂
C
t
∂
C
~
t
∂
C
~
t
∂
h
t
−
1
∂
h
t
−
1
∂
C
t
−
1
+
∂
C
t
∂
C
t
−
1
\begin{aligned} \frac{\partial C_{t}}{\partial C_{t-1}} &=\frac{\partial C_{t}}{\partial f_{t}} \frac{\partial f_{t}}{\partial h_{t-1}} \frac{\partial h_{t-1}}{\partial C_{t-1}}+\frac{\partial C_{t}}{\partial i_{t}} \frac{\partial i_{t}}{\partial h_{t-1}} \frac{\partial h_{t-1}}{\partial C_{t-1}} \\ &+\frac{\partial C_{t}}{\partial \widetilde{C}_{t}} \frac{\partial \widetilde{C}_{t}}{\partial h_{t-1}} \frac{\partial h_{t-1}}{\partial C_{t-1}}+\frac{\partial C_{t}}{\partial C_{t-1}} \end{aligned}
∂Ct−1∂Ct=∂ft∂Ct∂ht−1∂ft∂Ct−1∂ht−1+∂it∂Ct∂ht−1∂it∂Ct−1∂ht−1+∂C
t∂Ct∂ht−1∂C
t∂Ct−1∂ht−1+∂Ct−1∂Ct
根据上面的公式写出完整的表达式:
∂
C
t
∂
C
t
−
1
=
C
t
−
1
σ
′
(
⋅
)
W
f
∗
o
t
−
1
tanh
′
(
C
t
−
1
)
+
C
~
t
σ
′
(
⋅
)
W
i
∗
o
t
−
1
tanh
′
(
C
t
−
1
)
+
i
t
tanh
′
(
⋅
)
W
C
∗
o
t
−
1
tanh
′
(
C
t
−
1
)
+
f
t
\begin{aligned} \frac{\partial C_{t}}{\partial C_{t-1}} &=C_{t-1} \sigma^{\prime}(\cdot) W_{f} * o_{t-1} \tanh ^{\prime}\left(C_{t-1}\right) \\ &+\widetilde{C}_{t} \sigma^{\prime}(\cdot) W_{i} * o_{t-1} \tanh ^{\prime}\left(C_{t-1}\right) \\ &+i_{t} \tanh ^{\prime}(\cdot) W_{C} * o_{t-1} \tanh ^{\prime}\left(C_{t-1}\right) \\ &+f_{t} \end{aligned}
∂Ct−1∂Ct=Ct−1σ′(⋅)Wf∗ot−1tanh′(Ct−1)+C
tσ′(⋅)Wi∗ot−1tanh′(Ct−1)+ittanh′(⋅)WC∗ot−1tanh′(Ct−1)+ft
RNN中, ∏ j = k + 1 t ∂ S j ∂ S j − 1 \prod_{j=k+1}^{t} \frac{\partial S_{j}}{\partial S_{j-1}} ∏j=k+1t∂Sj−1∂Sj的结果最终会大于1or在0-1之间,这会导致梯度爆炸or消失。
但LSTM中,每个时间步中的 ∂ C t ∂ C t − 1 \frac{\partial C_t}{\partial C_{t-1}} ∂Ct−1∂Ct既可以大于1也可以取0-1(通过调整 f t f_t ft,提高则让梯度接近1),于是最终不会造成梯度问题。
值得注意的是: f t , i t , C ~ t f_t, i_t, \widetilde{C}_{t} ft,it,C t的值会被网络自动设置,因此,网络会学习在什么时候通过设置gata保留或消除梯度。
GRU
比LSTM有更少的参数,通过
r
t
r_t
rt控制是否遗忘上一个时间步的信息