一、RNN基本结构
1、隐层状态 s t s_t st
s t = σ ( U x t + W s t − 1 + b 1 ) s_t=\sigma(Ux_t+Ws_{t-1}+b_1) st=σ(Uxt+Wst−1+b1) σ \sigma σ()是激活函数,通常选用Tanh、ReLU。
2、输出状态 o t o_t ot
o t = g ( V s t + b 2 ) o_t=g(Vs_t+b_2) ot=g(Vst+b2) g g g()是激活函数,对于分类任务通常选用 s i g m o i d sigmoid sigmoid()。
3、Loss计算
输出状态
o
t
o_t
ot与目标输出
y
t
y_t
yt计算Loss:
L
=
∑
t
L
t
=
∑
t
L
o
s
s
(
o
t
,
y
t
)
L=\sum_{t}L_t=\sum_{t}Loss(o_t,y_t)
L=t∑Lt=t∑Loss(ot,yt)
L
o
s
s
Loss
Loss是损失函数,对于分类任务通常选用交叉熵损失函数。
二、RNN参数更新方式
1、首先需要明确:上述的循环重复结构,都是共享参数的,也就是说不管在什么时刻,权重矩阵 U U U、 W W W、 V V V都是相同的。
好处:极大减少参数量+可以处理不定长序列。
2、梯度下降、反向传播过程
假设
t
=
3
t=3
t=3的时刻,计算它的损失函数:
s
3
=
σ
(
U
x
3
+
W
s
2
+
b
1
)
o
3
=
g
(
V
s
3
+
b
2
)
L
3
=
1
2
(
o
3
−
y
3
)
2
s_3=\sigma(Ux_3+Ws_{2}+b_1) \\ o_3=g(Vs_3+b_2) \\ L_3=\frac{1}{2}(o_3-y_3)^2
s3=σ(Ux3+Ws2+b1)o3=g(Vs3+b2)L3=21(o3−y3)2那么求偏导的时候:
∂
L
3
∂
V
=
∂
L
3
∂
o
3
∂
o
3
∂
V
\frac{ \partial L_3 }{ \partial V}=\frac{ \partial L_3 }{ \partial o_3}\frac{ \partial o_3 }{ \partial V}
∂V∂L3=∂o3∂L3∂V∂o3
∂
L
3
∂
U
=
∂
L
3
∂
o
3
∂
o
3
∂
s
3
∂
s
3
∂
U
+
∂
L
3
∂
o
3
∂
o
3
∂
s
3
∂
s
3
∂
s
2
∂
s
2
∂
U
+
∂
L
3
∂
o
3
∂
o
3
∂
s
3
∂
s
3
∂
s
2
∂
s
2
∂
s
1
∂
s
1
∂
U
\frac{ \partial L_3 }{ \partial U}=\frac{ \partial L_3 }{ \partial o_3}\frac{ \partial o_3 }{ \partial s_3} \frac{ \partial s_3 }{ \partial U}+\frac{ \partial L_3 }{ \partial o_3}\frac{ \partial o_3 }{ \partial s_3} \frac{ \partial s_3 }{ \partial s_2}\frac{ \partial s_2 }{ \partial U}+\frac{ \partial L_3 }{ \partial o_3}\frac{ \partial o_3 }{ \partial s_3} \frac{ \partial s_3 }{ \partial s_2}\frac{ \partial s_2 }{ \partial s_1}\frac{ \partial s_1 }{ \partial U}
∂U∂L3=∂o3∂L3∂s3∂o3∂U∂s3+∂o3∂L3∂s3∂o3∂s2∂s3∂U∂s2+∂o3∂L3∂s3∂o3∂s2∂s3∂s1∂s2∂U∂s1
∂
L
3
∂
W
=
∂
L
3
∂
o
3
∂
o
3
∂
s
3
∂
s
3
∂
W
+
∂
L
3
∂
o
3
∂
o
3
∂
s
3
∂
s
3
∂
s
2
∂
s
2
∂
W
+
∂
L
3
∂
o
3
∂
o
3
∂
s
3
∂
s
3
∂
s
2
∂
s
2
∂
s
1
∂
s
1
∂
W
\frac{ \partial L_3 }{ \partial W}=\frac{ \partial L_3 }{ \partial o_3}\frac{ \partial o_3 }{ \partial s_3} \frac{ \partial s_3 }{ \partial W}+\frac{ \partial L_3 }{ \partial o_3}\frac{ \partial o_3 }{ \partial s_3} \frac{ \partial s_3 }{ \partial s_2}\frac{ \partial s_2 }{ \partial W}+\frac{ \partial L_3 }{ \partial o_3}\frac{ \partial o_3 }{ \partial s_3} \frac{ \partial s_3 }{ \partial s_2}\frac{ \partial s_2 }{ \partial s_1}\frac{ \partial s_1 }{ \partial W}
∂W∂L3=∂o3∂L3∂s3∂o3∂W∂s3+∂o3∂L3∂s3∂o3∂s2∂s3∂W∂s2+∂o3∂L3∂s3∂o3∂s2∂s3∂s1∂s2∂W∂s1因为
s
3
s_3
s3是由前面的
s
1
s_1
s1、
s
2
s_2
s2递推出来的,所以
L
L
L对
U
U
U、
W
W
W求偏导的公式需要把前面的
s
1
s_1
s1、
s
2
s_2
s2带入进去:
s
3
=
σ
(
U
x
3
+
W
s
2
+
b
1
)
=
σ
(
U
x
3
+
W
(
σ
(
U
x
2
+
W
s
1
+
b
1
)
)
+
b
1
)
s_3=\sigma(Ux_3+Ws_{2}+b_1)\\ =\sigma(Ux_3+W(\sigma(Ux_2+Ws_{1}+b_1))+b_1)
s3=σ(Ux3+Ws2+b1)=σ(Ux3+W(σ(Ux2+Ws1+b1))+b1)由此能知道,时间序列越长,出现连乘的部分会越集中在后面。也就是通过时间的反向传播。
三、RNN和普通神经网络梯度消失的本质区别
普通神经网络:它不是按时间步进行反向传播的,因此不会有一项一项相加的部分,只有一个总体的连乘求偏导过程。它的梯度消失是总的梯度会趋于0的。
RNN:每一项一项进行相加,可以发现距离拉的越长,连乘的项就越多,远距离的梯度会趋于0的,近距离的梯度不会消失。RNN梯度消失的真正含义是总的梯度受近距离梯度的主导,远距离的梯度消失。
四、RNN梯度消失梯度爆炸及解决方式
梯度爆炸:采用梯度截断的方式
梯度消失:1、采用跨时域的残差连接 。 2、采用门控机制(LSTM、GRU)作为RNN基本单元,控制信息流入量