Recurrent Neural Network
RNN 又叫做递归神经网络或者循环神经网络,它擅长对序列数据进行建模处理,如时间序列数据,是指在不同时间点上收集的数据,这类数据反映了某一事物、现象在随时间的变化状态或程度,当然这是时间,也可以是文本或图像序列,总的来说,序列数据存在着一个特点——后面的数据跟前面的数据有关系
为什么需要 RNN ?
神经网络结构只能单独的处理一个个的输入,前一个输入与后一个输入是完成没有关系的,但是某些任务需要更好的处理序列信息,即前面一个输入和后面一个输入是要有关系的,通俗点来说就是后一个输入需要记忆前面一个输入的信息
比如,当我们在理解一句话时,孤立的理解单个词是没有意义的,只有将上下词联系起来的整个序列才具有意义;当我们处理视频时,也不能单独分析每一帖,需要分析这些帧连接起来的整个序列
为了解决这一类问题,能够更好的处理序列信息,RNN 模型就应运而生,那么 RNN 又是怎样实现这样的功能呢?
RNN 的结构
RNN 主要对序列数据进行序列处理,其基本结构如下图所示:
上图是 RNN 的结构示意图,每个箭头表示着一次变换,也就是说箭头带有权值,左侧是折叠起来的样子,右侧是展开的样子,左侧 A 旁边的箭头体现着结构中的 “循环” 概念。
在右侧展开结构中我们可以看到,在 x 0 x_0 x0 作为输入时,该单元的输出分为二个方向,向上的 h o h_o ho 表示的是其作为一个输出,向右箭头表示的是其另一个输出作为下一个单元的输入,以此达到与下一个单元之间保持着某个联系,即记忆功能
为了更好的理解,我们看下图:
简单点来说就是:当在 x t x_t xt 时刻时,该单元的输入就分为二个: S t − 1 S_{t-1} St−1 、 x t x_t xt, 输入也分为二个: S t S_{t} St 、 O t O_t Ot
- S t − 1 S_{t-1} St−1 表示的是 x t − 1 x_{t-1} xt−1 时刻的一个输出
- x t x_t xt 表示本时刻的一个输入
- S t S_{t} St 表示 x t x_t xt 时刻的一个输出,将作为下一时刻的一个输入
- O t O_t Ot 表示 x t x_t xt 时刻的输出
我们可以用下面的公式来表示 RNN 的计算方式:
- 上图同样与展现出了 RNN 的另一个特点:权值共享,其中 U 是完全相同的, W、V也是一样的
那么我们再来看看 隐藏层 S 中究竟发生了怎样的变化
我们可以看到 h t − 1 h_{t-1} ht−1 和 x t x_t xt 之间实际上是做了一个 ocncatenate 操作,然后再经过激活函数最终形成了一个输出,值得注意的是它的一个维度变化
Bidirectional RNNs 双向循环神经网络
基本的 RNN 结构只能从之前时间步骤中学习,但是有时我们却需要从未来的时间步骤中学习表示,以便更好地理解上下文环境并消除歧义,通过接下来的列子,“He said, Teddy bears are on sale” and “He said, Teddy Roosevelt was a great President。在上面的两句话中,当我们看到“Teddy”和前两个词“He said”的时候,我们有可能无法理解这个句子是指President还是Teddy bears。因此,为了解决这种歧义性,我们需要往前查找。这就是双向RNN所能实现的。
如图所求,双向 RNN 有二种类型的连接,一种是前向的(Foward RNN),这有助于我们从之前的表示中学习, 另一种是后向的(Backward RNN),这有助于我们从之后的表示中学习
正向传播分为二个步骤:
-
我们先从左向右移动,从初始时间步骤开始计算,一直持续到到达最终时间步骤为止
-
再从右向左移动,从最后一个时间步骤开始计算,一直持续到到达最终时间步骤为止
- 一般来说是从前往向计算,再从后往前计算,计算过程相互独立,互不干扰
计算预测输出值就变成了:
y
^
<
t
>
=
g
(
W
y
[
a
→
<
t
>
,
a
←
<
t
>
]
)
\hat{y}^{<t>}= g(W_y[\overrightarrow{a}^{<t>},\overleftarrow{a}^{<t>}])
y^<t>=g(Wy[a<t>,a<t>])
a
→
<
t
>
\overrightarrow{a}^{<t>}
a<t>表示 Forward RNN 的激活函数,
a
←
<
t
>
\overleftarrow{a}^{<t>}
a<t> 表示 Backward RNN 的激活函数,箭头方向表示的传递方向
梯度消失和梯度爆炸
误差梯度在网络训练中用来得到网络参数的方向和步幅,在正确的方向下以合适的步幅更新网络参数。
梯度爆炸:在递归神经网络中,误差梯度会在更新中累积得到一个非常大的梯度,这样的梯度会大幅更新网络参数,导致网络的不稳定,在极端情况下,权值会变得非常的大以至于结果会溢出(NaN值、无穷或非数值),当梯度爆炸发生时,网络层之间反复乘以大于1.0的值使得梯度值成倍增长
梯度更新:如果误差梯度在更新中累积得到一个非常小的梯度,这也就意味着权值无法更新,最终导致训练失败
利用公式分析原因
经典 RNN 的结构如下图所求:
关于向前传播:
假设我们的时间序列只有三段,
S
0
S_0
S0 为定值,神经元没有激活函数(便于分析)就可获得各个时间段的状态和输出:
t
=
1
时
刻
S
1
=
U
X
1
+
W
S
0
+
b
1
O
1
=
V
S
1
+
b
2
\begin{aligned}&t = 1 \text{ }时刻\\&S_1 = UX_1 + WS_0 + b_1\\&O_1 = VS_1 + b_2\end{aligned}\\
t=1 时刻S1=UX1+WS0+b1O1=VS1+b2
t = 2 时 刻 S 2 = U X 2 + W S 1 + b 1 O 2 = V S 2 + b 2 \begin{aligned}&t = 2 \text{ }时刻\\&S_2 = UX_2 + WS_1 + b_1\\&O_2 = VS_2 + b_2\end{aligned}\\ t=2 时刻S2=UX2+WS1+b1O2=VS2+b2
t = 3 时 刻 S 3 = U X 3 + W S 2 + b 1 O 3 = V S 3 + b 2 \begin{aligned}&t = 3 \text{ }时刻\\&S_3 = UX_3 + WS_2 + b_1\\&O_3 = VS_3 + b_2\end{aligned}\\ t=3 时刻S3=UX3+WS2+b1O3=VS3+b2
损失函数采用交叉熵
L
t
=
−
O
t
‾
l
o
g
O
t
L_t=-\overline{O_t}logO_t
Lt=−OtlogOt (
O
t
O_t
Ot是 t 时刻的预测输出,
O
t
‾
\overline{O_t}
Ot是 t 时刻的真实输出),那么对于一次训练任务中,损失函数为:
L
=
∑
i
=
1
T
−
O
t
‾
l
o
g
O
t
L = \sum_{i=1}^{T}-\overline{O_t}logO_t
L=i=1∑T−OtlogOt
T 是序列总长度,上述公式为每一时刻损失值的累加
关于反射传播:
我们只对 t3 时时刻的 U、V、W 求偏导,由链式法则可得:
∂
L
3
∂
V
=
∂
L
3
∂
O
3
∂
O
3
∂
V
∂
L
3
∂
W
=
∂
L
3
∂
O
3
∂
O
3
∂
S
3
∂
S
3
∂
W
+
∂
L
3
∂
O
3
∂
O
3
∂
S
2
∂
S
2
∂
W
+
∂
L
3
∂
O
3
∂
O
3
∂
S
3
∂
S
3
∂
S
2
∂
S
2
∂
S
1
∂
S
1
∂
W
∂
L
3
∂
U
=
∂
L
3
∂
O
3
∂
O
3
∂
S
3
∂
S
3
∂
U
+
∂
L
3
∂
O
3
∂
O
3
∂
S
2
∂
S
2
∂
U
+
∂
L
3
∂
O
3
∂
O
3
∂
S
3
∂
S
3
∂
S
2
∂
S
2
∂
S
1
∂
S
1
∂
U
\begin{aligned}&\frac{\partial{L_3}}{\partial{V}} = \frac{\partial{L_3}}{\partial{O_3}}\frac{\partial{O_3}}{\partial{V}}\\&\frac{\partial{L_3}}{\partial{W}} = \frac{\partial{L_3}}{\partial{O_3}}\frac{\partial{O_3}}{\partial{S_3}}\frac{\partial{S_3}}{\partial{W}} + \frac{\partial{L_3}}{\partial{O_3}}\frac{\partial{O_3}}{\partial{S_2}}\frac{\partial{S_2}}{\partial{W}} + \frac{\partial{L_3}}{\partial{O_3}}\frac{\partial{O_3}}{\partial{S_3}}\frac{\partial{S_3}}{\partial{S_2}}\frac{\partial{S_2}}{\partial{S_1}}\frac{\partial{S_1}}{\partial{W}}\\&\frac{\partial{L_3}}{\partial{U}} = \frac{\partial{L_3}}{\partial{O_3}}\frac{\partial{O_3}}{\partial{S_3}}\frac{\partial{S_3}}{\partial{U}} + \frac{\partial{L_3}}{\partial{O_3}}\frac{\partial{O_3}}{\partial{S_2}}\frac{\partial{S_2}}{\partial{U}} + \frac{\partial{L_3}}{\partial{O_3}}\frac{\partial{O_3}}{\partial{S_3}}\frac{\partial{S_3}}{\partial{S_2}}\frac{\partial{S_2}}{\partial{S_1}}\frac{\partial{S_1}}{\partial{U}}\end{aligned}
∂V∂L3=∂O3∂L3∂V∂O3∂W∂L3=∂O3∂L3∂S3∂O3∂W∂S3+∂O3∂L3∂S2∂O3∂W∂S2+∂O3∂L3∂S3∂O3∂S2∂S3∂S1∂S2∂W∂S1∂U∂L3=∂O3∂L3∂S3∂O3∂U∂S3+∂O3∂L3∂S2∂O3∂U∂S2+∂O3∂L3∂S3∂O3∂S2∂S3∂S1∂S2∂U∂S1
可以简写成:
∂
L
3
∂
U
=
∑
k
=
0
3
∂
L
3
∂
O
3
∂
O
3
∂
S
3
∂
S
3
∂
S
k
∂
S
k
∂
U
=
∑
k
=
0
3
∂
L
3
∂
O
3
∂
O
3
∂
S
3
(
∏
j
=
k
−
1
3
∂
S
j
∂
S
j
−
1
)
∂
S
k
∂
U
任
意
时
刻
对
参
数
W
求
偏
导
的
公
式
:
∂
L
3
∂
W
=
∑
k
=
0
t
∂
L
3
∂
O
3
∂
O
3
∂
S
3
(
∏
j
=
k
−
1
t
∂
S
j
∂
S
j
−
1
)
∂
S
k
∂
w
\begin{aligned}&\frac{\partial{L_3}}{\partial{U}} = \sum_{k=0}^{3}\frac{\partial{L_3}}{\partial{O_3}}\frac{\partial{O_3}}{\partial{S_3}}\frac{\partial{S_3}}{\partial{S_k}}\frac{\partial{S_k}}{\partial{U}}= \sum_{k=0}^{3}\frac{\partial{L_3}}{\partial{O_3}}\frac{\partial{O_3}}{\partial{S_3}}\left( \prod_{j=k-1}^{3}\frac{\partial{S_j}}{\partial{S_{j-1}}} \right)\frac{\partial{S_k}}{\partial{U}}\\&任意时刻对参数 W 求偏导的公式:\\&\frac{\partial{L_3}}{\partial{W}} =\sum_{k=0}^{t}\frac{\partial{L_3}}{\partial{O_3}}\frac{\partial{O_3}}{\partial{S_3}}\left( \prod_{j=k-1}^{t}\frac{\partial{S_j}}{\partial{S_{j-1}}} \right)\frac{\partial{S_k}}{\partial{w}}\\\end{aligned}
∂U∂L3=k=0∑3∂O3∂L3∂S3∂O3∂Sk∂S3∂U∂Sk=k=0∑3∂O3∂L3∂S3∂O3⎝⎛j=k−1∏3∂Sj−1∂Sj⎠⎞∂U∂Sk任意时刻对参数W求偏导的公式:∂W∂L3=k=0∑t∂O3∂L3∂S3∂O3⎝⎛j=k−1∏t∂Sj−1∂Sj⎠⎞∂w∂Sk
由此可以看出 V 求偏导不存在依赖关系,而 W、U则随时间长度存在着长期的依赖关系,因为 S t S_t St 会随着时间序列向前传播,而同时 S t S_t St 是 U、W 的函数
如果取其中的累乘出来,其中激活函数通常是:tanh = [0, 1] 则:
∏
j
=
k
−
1
t
∂
S
j
∂
S
j
−
1
=
∏
j
=
k
−
1
t
t
a
n
h
′
W
\prod_{j=k-1}^{t}\frac{\partial{S_j}}{\partial{S_{j-1}}} = \prod_{j=k-1}^{t}tanh^{'}W
j=k−1∏t∂Sj−1∂Sj=j=k−1∏ttanh′W
- 由上图可以看出 t a n h ′ ∈ [ 0 , 1 ] tanh^{'}\in [0, 1] tanh′∈[0,1] , 也就是说大部分都是 小于1的数在做累乘,假设 W 也是一个大于0小于1的值时,当 t 很大时, ∏ j = k − 1 t t a n h ′ W 公 式 中 的 ∏ j = k − 1 t t a n h ′ \prod_{j=k-1}^{t}tanh^{'}W 公式中的 \prod_{j=k-1}^{t}tanh^{'} ∏j=k−1ttanh′W公式中的∏j=k−1ttanh′ 部分会趋向于 0,这就是 RNN 中梯度消失的原因
- 同理, ∏ j = k − 1 t t a n h ′ W 公 式 中 的 \prod_{j=k-1}^{t}tanh^{'}W 公式中的 ∏j=k−1ttanh′W公式中的 W 参数很大时,结果就会趋于无穷,这就是产生 梯度爆炸 的原因
解决办法
面对梯度爆炸的问题,我们可以看到梯度爆炸是因为 W 参数的值过大,而 W 值随着序列长度存在长期的依赖关系,因而我们可以设置一个上限值,一旦超过上限值,就等于我们的预设值,这样就可以解决梯度爆炸的问题了
面对梯度消失的问题,梯度消失的原因是 ∏ j = k − 1 t ∂ S j ∂ S j − 1 \prod_{j=k-1}^{t}\frac{\partial{S_j}}{\partial{S_{j-1}}} ∏j=k−1t∂Sj−1∂Sj 求导而产生的,因此想要消除这种情况就需要在求领导的时候去掉就行了,那么怎样去掉呢,一般有二种方法:
- 使 ∂ S j ∂ S j − 1 ≈ 1 \frac{\partial{S_j}}{\partial{S_{j-1}}} \approx 1 ∂Sj−1∂Sj≈1,那么怎样达到这种目标呢?答案是换一种激活函数,我们来看一下 ReLu 作为激活函数的效果:
可以看到 ReLu 导数在定义域大于0的部分是恒等于1,这样就可以解决梯度消失的问题了
- 使 ∂ S j ∂ S j − 1 ≈ 0 \frac{\partial{S_j}}{\partial{S_{j-1}}} \approx 0 ∂Sj−1∂Sj≈0,我们可以采用 LSTM 可以达到这样的效果,那么 LSTM 又是怎样实现的呢,我们在下一篇文章中再来详细解决
参考文献:
[1]. https://www.jiqizhixin.com/articles/2019-01-17-7
声明:
以上内容为个人理解,若有错误,请各位大佬指出,以便大家多作交流!