1. 写在前面
最近用深度学习做一些时间序列预测的实验, 用到了一些循环神经网络的知识, 而当初学这块的时候,只是停留在了表面,并没有深入的学习和研究,只知道大致的原理, 并不知道具体的细节,所以导致现在复现一些经典的神经网络会有困难, 所以这次借着这个机会又把RNN, GRU, LSTM以及Attention的一些东西复习了一遍,真的是每一遍学习都会有新的收获,之前学习过也没有整理, 所以这次也借着这个机会把这一块的基础内容进行一个整理和总结, 顺便了解一下这些结构底层的逻辑。
当然,这次的整理是查缺补漏, 类似于知识的串联, 一些很基础的内容可能不会涉及到, 这一部分由于篇幅很长,所以打算用三篇基础文章来整理, 分别是重温循环神经网络RNN, 重温LSTM和GRU和重温Seq2Seq与Attention机制。 整理完了这些基础知识 , 然后会总结一篇用于时间序列预测非线性自回归模型的论文,这篇论文用的就是带有双阶段注意力机制的LSTM,后面也会使用keras尝试复现并用于时间序列预测的任务,通过这样的方式,可以把这些基础知识从理论变成实践, 这也是先整理三篇基础文章的原因, 因为复现过程中发现一些细节懵懵懂懂, 所以还是先温习遍基础 😉
第一篇就是重温RNN, 这里面会先从全连接神经网络开始, 看一下神经网络到底长什么样子以及如何进行计算, 然后针对一些特殊的任务说一下全连接神经网络的局限性引出循环神经网络架构, 然后根据这个结构说一些基础知识和运算细节, 并用numpy简单实现一下RNN的前向传播过程, 最后分析一下传统RNN的局限性, 通过反向传播的公式看一下RNN为什么会存在梯度消失和爆炸, 而有了梯度消失为啥又不能捕捉到长期关联, 如何解决梯度消失问题等。 通过解决方法引出LSTM和变体GRU, 再去探索这两个的原理和一些实现细节。
大纲如下:
- 理解RNN? 我们先从全连接神经网络开始
- 关于RNN结构的基础知识和计算细节
- RNN前向传播的numpy实现与RNN的局限性
- 总结
Ok, let’s go!
2. RNN? 我们还是先从神经网络开始说起吧
说到神经网络, 我们肯定是不陌生了, 并且也非常熟悉它的运算过程, 拿我整理Pytorch的时候的一张图再回顾一下神经网络:
上面其实就是全连接网络的一个总体计算过程,左上就是一个全连接神经网络示意图, 一个全连接神经网络有一个输入层, 若干个隐藏层和一个输出层, 它的计算步骤包括前向传播, 计算损失, 反向传播, 更新参数,然后重复这个过程。 具体细节就不再这里展开了, 这种网络功能也是非常强大, 由于激活函数的存在,也善于学习很复杂的非线性关系。
但是有些任务, 比如我们的输入是一个句子: Cat is beautiful! 让这个神经网络进行翻译, 我们一般要这么做, 首先,会把上面这3个单词转成向量的形式,要不然模型不认识, 可以通过one-hot或者embedding等, 然后我们喂入神经网络, 得到输出:
应该是一个这样的过程, 上面这个图得好好理解一下, 这就是如果基于全连接网络的话会是这样的一个图, 这里之所以画成3步,就是为了后面更好的理解循环神经网络, 如果看过吴恩达老师的深度学习, 这里画的是这样的一个图:
这里也拿来做个对比吧, 这个图的话很容易把特征和不同时间步的序列给搞混了, 并且不利于后面和递归神经网络进行对比,所以我把每个单词的翻译给分开了, 分别通过神经网络进行翻译。
但是上面这种网络存在一些问题, 很大的一个问题就是单词和单词之间的翻译孤立起来了, 没有关联了, 但是我们知道句子翻译很大程度上是依赖于上下文的, 如果不看上下文, 很容易把某个词翻译错的。比如我前面的cat换成cats, 后面的is就需要换成are, 但是在上面的神经网络里面, 是学习不到这种词与词之间的关联关系, 所以这种神经网络对于这种时序性的任务不擅长, 也就是说如果我的输入是一串序列,并且这串序列前后之间有关联关系, 比如一个句子, 一段音乐, 一段语音,一段视频, 一段随时间变化的数据(股票,温度)等这样的数据, 如果想用一个网络对这样的数据进行建模, 比如捕捉这些前后的关联关系,全连接神经网络是不行的,什么? 还有CNN?CNN1D确实可以处理一些简单的时间序列数据,卷积神经网络也确实可以采用滑动窗口的那种思想去捕捉局部的一些特征, 但是功能比较受限(长距离依赖没法学习), 于是循环神经网络诞生了。
3. 关于RNN结构的基础知识和计算细节
啥叫循环神经网络呢? 这里的循环到底干什么事情呢? 下面这个就是循环神经网络的图, 通过这个图很容易看到循环吧, 但是对于初学者来说,这个图并不是那么好理解:
其实,虽然这个图不是那么好理解, 那还是这个图能够真正的表示循环神经网络,更能看出一种循环, 简单的说, 循环神经网络在做一件这样的事情:
我们的输入序列不是说有时间的先后关系吗?我们不是说要捕捉不同时间步中输入数据的关联吗? 看看RNN是如何做的:
我们不妨设t-1, t, t+1三个时刻, 首先神经网络会接收t-1时刻的输入 X t − 1 X_{t-1} Xt−1进行运算, 然后求出隐藏状态 S t − 1 S_{t-1} St−1和输出 O t − 1 O_{t-1} Ot−1, 计算完毕之后, 会把隐藏状态的值 S t − 1 S_{t-1} St−1和t时刻的输入 X t X_{t} Xt同时作为t时刻的神经网络的运算输入, 然后进行计算得到 S t S_{t} St和 O t O_{t} Ot, 计算完毕之后, 把t时刻的隐态 S t S_t St与t+1时刻的输入 X t + 1 X_{t+1} Xt+1作为t+1时刻神经网络的输入, 计算 S t + 1 和 O t + 1 S_{t+1}和O_{t+1} St+1和Ot+1, 这个过程是一气呵成的, 之所以称之为循环,就是因为它需要在多个时间步中反复执行这个计算过程, 而后面时间步里面的计算,需要用到前面时间步中的结果, 通过这种方式去捕捉序列之间的关联关系。
下面看两张动图感受一下这个过程:
第一个过程, 每个时间步接收一个输入, 并进行计算处理
第二个过程, 前一时间步处理的结果要传递到下一个时间步
所以上面这个过程我们可以用下面的公式表示:
O
t
=
g
(
V
⋅
S
t
+
b
o
)
S
t
=
f
(
U
⋅
X
t
+
W
⋅
S
t
−
1
+
b
s
)
\begin{array}{l} O_{t}=g\left(V \cdot S_{t}+ b_o\right ) \\ S_{t}=f\left(U \cdot X_{t}+W \cdot S_{t-1}+b_s\right) \end{array}
Ot=g(V⋅St+bo)St=f(U⋅Xt+W⋅St−1+bs)
也就是当前时刻t的隐藏状态
S
t
S_{t}
St不仅仅取决于当前的输入
X
t
X_t
Xt, 还取决于前一个时刻的隐藏状态值
S
t
−
1
S_{t-1}
St−1, 这里的
g
,
f
g,f
g,f激活函数了。看下图可能会更加清楚:
如果是把我上面举得那个例子拿下来的话,就是这样的一个感觉
所以这里要注意一些细节:
- 不要以为这是很多个全连接神经网络,其实这就是一个神经网络,只不过不同的时间步用了不同的输入而已。
- 这里的前向传播过程是一气呵成的, 就是在一个时间步的循环中,直接进行每个时间步的前向传播,得到最后的结果。
- 注意这里的可学习参数 W , V , U W, V, U W,V,U, 不同的时间步里面都使用的这一套参数, 所以这里的参数是共享的, 参数共享有很多好处, 比如减少计算量, 比如特征提取, 也可以让模型更好的泛化, 比如我去年去了北京, 和去年我去了北京, 这两个句子意思一样, 但是文字位置不同,共享的参数有利于学习词义本身而不是每个位置的规则。
- 这里还要注意几个名词, 第一个就是timesteps, 表示时间步长, 也就是时间序列的长度, 需要循环迭代的次数, 第二个就是input_dim, 这个表示的每个时间步的输入数据有多少个特征, 第三个是units, 这个指的是上面隐藏层有多少个神经单元, 为什么要说这三个名词呢? 因为在使用实际用RNN或者LSTM的时候,这三个是核心参数,后面整理LSTM的时候,会看看keras的LSTM层如何用,那时候会再次看到这三个名词
下面我们把上面按照时间线展开的RNN换一种形式表示,就是把那个圆圈给它再放大放大,进来看看细节:
这个就是RNN按照时间线展开的图了,这里的符号可能和上面表示的不一样,这里我就先不统一符号了,毕竟参考的资料不一样, 如果真懂了运算原理, 就不会在乎符号的问题, 并且这里主要也是说明计算原理,上面这个图是取自吴恩达老师的深度学习课程, 这里的RNN-cell, 可以理解成那个隐藏层, 里面当然很多个隐藏单元, 我们可以看一下这里面的整体计算:
这里与上面不同的是,指明了具体的激活函数
g
,
f
g, f
g,f了,这个公式和上面循环神经网络的计算公式一样, 无非是符号换了一下。
下面我们可以基于上面的这个RNN的运算过程, 用numpy简单的写一下。为了看清楚这个过程, 还找了张动图:
动图后面会给出参考链接。
4. RNN前向传播的numpy实现与RNN的局限性
根据上面的图, 我们就用numpy代码简单实现一下RNN的前向传播,这样更容易里面RNN的前向传播过程, 首先,依然是先定义上面细节中的三个名词:timesteps, input_dim和units, 这里我们假设时间步长是4, input_dim是3, units是5, 然后10个样本。 实现过程,我们先看看一个RNN-cell里面的计算, 把上面的图拿下来:
先实现一个Cell里面的计算过程, 我们可以先看一下这里面的输入有当前时间步的输入数据xt,
前一时间步的输入数据a_prev
, 然后输出有a_t
, yt_pred
, 而计算公式就是上面那个,参数有
W
a
x
W_{ax}
Wax, 维度是(5, 3), 这个根据input_dim和units确定的,因为它描述的是输入和隐藏单元之间的一种映射,
W
a
a
W_{aa}
Waa, 维度是(5, 5), 这个是units确定, 因为它描述的是隐藏单元与下一个时间步隐藏单元的映射,
W
y
a
W_{ya}
Wya, 维度是(2, 5),描述的是输出与隐藏单元的映射, 所以可以直接定义一个函数, 写这个计算过程:
def rnn_cell_forward(xt, a_prev, parameters):
# 获得参数
Wax = parameters["Wax"]
Waa = parameters["Waa"]
Wya = parameters["Wya"]
ba = parameters["ba"]
by = parameters["by"]
# cell 的前向传播
a_t = np.tanh(np.dot(Wax, xt) + np.dot(Waa, a_prev) + ba)
yt_pred = softmax(np.dot(Wya, a_t) + by)
# 保存一下重要结果
cache = (a_t, a_prev, xt, parameters)
return a_t, yt_pred, cache
xt = np.random.randn(3,10)
a_prev = np.random.randn(5,10)
# 初始化参数
Waa = np.random.randn(5, 5)
Wax = np.random.randn(5, 3)
Wya = np.random.randn(2, 5)
ba = np.random.randn(5, 1)
by = np.random.randn(2, 1)
a_next, yt_pred, cache = rnn_cell_forward(xt, a_prev, parameters)
这就是cell的前向传播, 当然这里面有一些细节, 比如像那些参数, a_next, a_prev, xt这些东西, 最好都保存一下,反向传播的时候会用到。
一个cell的前向传播完毕, 那么整个RNN的前向传播应该咋写呢? 还是看图
有了一个cell的计算, 整个RNN其实就是时间步的一个循环, 所以可以用一个时间步循环解决这个问题, 还是先分析一下, 接收的数据是a0和整个x, 这个x的维度就是(input_dim, m, timesteps), 而a0的维度就是(units, m), 而这里的输出有最后一步的a, 这个维度是(units, m, timesteps), y_pred, 维度是(n_y, m, timesteps)。 过程就是遍历每个时间步, 得到本时间步的输出y和下一步的输入a_next, 把这个加入到最后的y和a里面即可。
def rnn_forward(x, a0, parameters):
caches = [] # 保存结果
# 获取到那几个重要的参数
input_dim, m, T_x = x.shape
n_y, units = parameters['Wya'].shape
# 初始化a, y_pred
a = np.zeros((units, m, T_x))
y_pred = np.zeros((n_y, m, T_x))
a_next = a0
for t in range(T_x):
a_next, yt_pred, cache = rnn_cell_forward(x[:,:,t], a_next, parameters)
a[:, :, t] = a_next
y_pred[:, :, t] = yt_pred
caches.append(cache)
caches = (caches, x)
return a, y_pred, caches
这就是RNN的前向传播过程, 这样理解这个循环神经网络的计算过程为啥是一气呵成了吧, 但是这里还要注意一下, 这个和普通的全连接前向传播的循环可不一样, 这里是只有一层隐藏层, 然后这里的循环是时间步的循环, 而全连接网络那里的循环是多个隐藏层, 循环是隐藏层的循环计算, 如果不理解的话很容易就搞乱了。这里是一层的RNN, 但是有一个时间步的循环计算, 而普通的一层全连接网络,是不用循环计算的。
那么这里又要看一个问题了, 我们知道全连接那分析的时候,如果层数很多, 就会出现梯度消失或者爆炸, 这是因为在反向传播的时候, 通过链式法则的推导,会用到上一层正向传播过程中的输出, 而这个输出,又依赖于前面层数的输出,这是一个连乘的计算过程, 所以如果前面某一层某个值很大或者很小的时候,就会导致后面某些层的输出很小, 这样就会导致梯度消失或者爆炸, 如果不知道我在说啥的,建议补一下基础, 或者看看系统学习Pytorch笔记六:模型的权值初始化与损失函数介绍, 这里面解释了一点梯度消失和爆炸现象。
而回到RNN, 其实也存在这个现象,为啥呢? 因为上面说了, 一层的RNN就会有一个时间步的循环计算, 而这个时间步的长度是依赖于输入序列的长度的, 如果序列很长很长,那么这里也相当于前向传播有了一个很深的连乘运算, 则RNN的反向传播过程会随着时间序列产生长期依赖,这是因为每一步的隐态
S
t
S_t
St随着时间序列在前向传播, 而
S
t
S_t
St又是
W
x
,
W
s
W_x, W_s
Wx,Ws的函数, 所以会有一个时间步之间隐态的一个连乘计算, 有连乘,就会出现危险, 如果不明白的话,看个计算过程就明白了, 顺便看一下RNN的反向传播:
就拿一个三个时间步的RNN来看, 通过上面的分析,我们可以写一下它的前向传播过程:
S
1
=
W
x
X
1
+
W
s
S
0
+
b
1
O
1
=
W
o
S
1
+
b
2
S
2
=
W
x
X
2
+
W
s
S
1
+
b
1
O
2
=
W
o
S
2
+
b
2
S
3
=
W
x
X
3
+
W
s
S
2
+
b
1
O
3
=
W
o
S
3
+
b
2
\begin{array}{l} S_{1}=W_{x} X_{1}+W_{s} S_{0}+b_{1} \qquad O_{1}=W_{o} S_{1}+b_{2} \\ S_{2}=W_{x} X_{2}+W_{s} S_{1}+b_{1}\qquad O_{2}=W_{o} S_{2}+b_{2} \\ S_{3}=W_{x} X_{3}+W_{s} S_{2}+b_{1} \qquad O_{3}=W_{o} S_{3}+b_{2} \end{array}
S1=WxX1+WsS0+b1O1=WoS1+b2S2=WxX2+WsS1+b1O2=WoS2+b2S3=WxX3+WsS2+b1O3=WoS3+b2
假设在t=3时刻, 损失函数为 L 3 = 1 2 ( Y 3 − O 3 ) 2 L_{3}=\frac{1}{2}\left(Y_{3}-O_{3}\right)^{2} L3=21(Y3−O3)2
则对于一次训练任务的损失函数为 L = ∑ t = 0 T L t L=\sum_{t=0}^{T} L_{t} L=∑t=0TLt, 就是每个时间步损失函数的一个累加。 那么我们开始考虑反向传播的过程, 其实就是对 W x , W s , W o , b 1 , b 2 W_x, W_s, W_o, b_1, b_2 Wx,Ws,Wo,b1,b2求偏导, 并不断调整它们使L尽可能达到最小。
那么我们就对t3时刻的
W
x
,
W
s
,
W
o
W_x, W_s, W_o
Wx,Ws,Wo求一下偏导:
∂
L
3
∂
W
o
=
∂
L
3
∂
O
3
∂
O
3
∂
W
o
=
(
O
3
−
Y
3
)
S
3
\frac{\partial L_{3}}{\partial W_{o}}=\frac{\partial L_{3}}{\partial O_{3}} \frac{\partial O_{3}}{\partial W_{o}}=(O_3-Y_3) S_3
∂Wo∂L3=∂O3∂L3∂Wo∂O3=(O3−Y3)S3
这个我们发现,对于
W
o
W_o
Wo求导, 并没有产生长期依赖。而下面看看对于
W
x
,
W
s
W_x, W_s
Wx,Ws求偏导:
∂
L
3
∂
W
x
=
∂
L
3
∂
O
3
∂
O
3
∂
S
3
∂
S
3
∂
W
x
+
∂
L
3
∂
O
3
∂
O
3
∂
S
3
∂
S
3
∂
S
2
∂
S
2
∂
W
x
+
∂
L
3
∂
O
3
∂
O
3
∂
S
3
∂
S
3
∂
S
2
∂
S
2
∂
S
1
∂
S
1
∂
W
x
∂
L
3
∂
W
s
=
∂
L
3
∂
O
3
∂
O
3
∂
S
3
∂
S
3
∂
W
s
+
∂
L
3
∂
O
3
∂
O
3
∂
S
3
∂
S
3
∂
S
2
∂
S
2
∂
W
s
+
∂
L
3
∂
O
3
∂
O
3
∂
S
3
∂
S
3
∂
S
2
∂
S
2
∂
S
1
∂
S
1
∂
W
s
\begin{aligned} \frac{\partial L_{3}}{\partial W_{x}}=\frac{\partial L_{3}}{\partial O_{3}} \frac{\partial O_{3}}{\partial S_{3}} \frac{\partial S_{3}}{\partial W_{x}}+\frac{\partial L_{3}}{\partial O_{3}} \frac{\partial O_{3}}{\partial S_{3}} \frac{\partial S_{3}}{\partial S_{2}} \frac{\partial S_{2}}{\partial W_{x}}+\frac{\partial L_{3}}{\partial O_{3}} \frac{\partial O_{3}}{\partial S_{3}} \frac{\partial S_{3}}{\partial S_{2}} \frac{\partial S_{2}}{\partial S_{1}} \frac{\partial S_{1}}{\partial W_{x}} \\ \\ \frac{\partial L_{3}}{\partial W_{s}}=\frac{\partial L_{3}}{\partial O_{3}} \frac{\partial O_{3}}{\partial S_{3}} \frac{\partial S_{3}}{\partial W_{s}}+\frac{\partial L_{3}}{\partial O_{3}} \frac{\partial O_{3}}{\partial S_{3}} \frac{\partial S_{3}}{\partial S_{2}} \frac{\partial S_{2}}{\partial W_{s}}+\frac{\partial L_{3}}{\partial O_{3}} \frac{\partial O_{3}}{\partial S_{3}} \frac{\partial S_{3}}{\partial S_{2}} \frac{\partial S_{2}}{\partial S_{1}} \frac{\partial S_{1}}{\partial W_{s}} \end{aligned}
∂Wx∂L3=∂O3∂L3∂S3∂O3∂Wx∂S3+∂O3∂L3∂S3∂O3∂S2∂S3∂Wx∂S2+∂O3∂L3∂S3∂O3∂S2∂S3∂S1∂S2∂Wx∂S1∂Ws∂L3=∂O3∂L3∂S3∂O3∂Ws∂S3+∂O3∂L3∂S3∂O3∂S2∂S3∂Ws∂S2+∂O3∂L3∂S3∂O3∂S2∂S3∂S1∂S2∂Ws∂S1
这两个就会产生一种时间序列依赖,因为
S
t
S_t
St随着时间序列在前向传播, 而
S
t
S_t
St又是
W
x
,
W
s
W_x, W_s
Wx,Ws的函数。 根据上面求偏导的过程, 可以得到任意一个时刻对
W
x
,
W
s
W_x, W_s
Wx,Ws求偏导的公式:
∂
L
t
∂
W
x
=
∑
k
=
0
t
∂
L
t
∂
O
t
∂
O
t
∂
S
t
(
∏
j
=
k
−
1
t
∂
S
j
∂
S
j
−
1
)
∂
S
k
∂
W
x
\frac{\partial L_{t}}{\partial W_{x}}=\sum_{k=0}^{t} \frac{\partial L_{t}}{\partial O_{t}} \frac{\partial O_{t}}{\partial S_{t}}\left(\prod_{j=k-1}^{t} \frac{\partial S_{j}}{\partial S_{j-1}}\right) \frac{\partial S_{k}}{\partial W_{x}}
∂Wx∂Lt=k=0∑t∂Ot∂Lt∂St∂Ot⎝⎛j=k−1∏t∂Sj−1∂Sj⎠⎞∂Wx∂Sk
对
W
s
W_s
Ws求偏导也是同理。 如果加上激活函数,
S
j
=
tanh
(
W
x
X
j
+
W
s
S
j
−
1
+
b
1
)
S_{j}=\tanh \left(W_{x} X_{j}+W_{s} S_{j-1}+b_{1}\right)
Sj=tanh(WxXj+WsSj−1+b1), 则
∏
j
=
k
+
1
t
∂
S
j
∂
S
j
−
1
=
∏
j
=
k
+
1
t
tanh
′
W
s
\prod_{j=k+1}^{t} \frac{\partial S_{j}}{\partial S_{j-1}}=\prod_{j=k+1}^{t} \tanh ^{\prime} W_{s}
j=k+1∏t∂Sj−1∂Sj=j=k+1∏ttanh′Ws
这样就清晰了, tanh的导数我们知道是小于等于1的, 而这个连乘的大小其实取决于这个 W s W_s Ws, 如果 W s W_s Ws很大, 那么在求导过程中就会梯度爆炸, 如果很小,那么就会出现梯度消失, 所以这就是RNN中梯度消失或者爆炸的原因, 关键之处就是这个连乘运算。 注意一下, 之类说的梯度消失,并不是说后面时刻参数更新时梯度为0, 而是说后面时刻参数更新的时候, 越往前的序列信息对更新起不到作用了。
并且我们假设有个t=20的时候看看
W
x
,
W
s
W_x, W_s
Wx,Ws求偏导的公式:
∂
L
20
∂
W
x
=
∂
L
20
∂
O
20
∂
O
20
∂
S
20
∂
S
20
∂
W
x
+
∂
L
20
∂
O
20
∂
O
20
∂
S
20
∂
S
20
∂
S
19
∂
S
19
∂
W
x
+
∂
L
20
∂
O
20
∂
O
20
∂
S
20
∂
S
20
∂
S
19
∂
S
19
∂
S
18
∂
S
18
∂
W
x
+
.
.
.
.
+
0
+
0
+
.
.
.
+
0
\begin{aligned} \frac{\partial L_{20}}{\partial W_{x}}=\frac{\partial L_{20}}{\partial O_{20}} \frac{\partial O_{20}}{\partial S_{20}} \frac{\partial S_{20}}{\partial W_{x}}+\frac{\partial L_{20}}{\partial O_{20}} \frac{\partial O_{20}}{\partial S_{20}} \frac{\partial S_{20}}{\partial S_{19}} \frac{\partial S_{19}}{\partial W_{x}}+\frac{\partial L_{20}}{\partial O_{20}} \frac{\partial O_{20}}{\partial S_{20}} \frac{\partial S_{20}}{\partial S_{19}} \frac{\partial S_{19}}{\partial S_{18}} \frac{\partial S_{18}}{\partial W_{x}} + ....+ 0 + 0 + ...+ 0\end{aligned}
∂Wx∂L20=∂O20∂L20∂S20∂O20∂Wx∂S20+∂O20∂L20∂S20∂O20∂S19∂S20∂Wx∂S19+∂O20∂L20∂S20∂O20∂S19∂S20∂S18∂S19∂Wx∂S18+....+0+0+...+0
而
∂
S
k
∂
W
x
=
X
k
\frac{\partial S_k}{\partial W_x}=X_k
∂Wx∂Sk=Xk, 这个其实也就是再说,如果某时刻距离当前时刻越远, 比如t=3, 也就是上面加法的后面一长串累乘到出现
S
3
S_3
S3的时候,因为有了这一长串累乘,很容易导致梯度消失,那么
∂
S
3
∂
W
x
=
X
3
\frac{\partial S_3}{\partial W_x}=X_3
∂Wx∂S3=X3不起作用了(因为累乘那块是0, 乘以这个梯度也是0), 这也就是说在t=20的时候,t=3时刻的输入对于t=20时参数更新
L
20
W
x
\frac{L_{20}}{W_x}
WxL20是起不到任何作用的。 这就相当于RNN并没有办法捕捉这种长期的依赖关系, 只能捕捉局部的依赖关系, 比如t=20时参数的更新,可能只依赖于
X
19
,
X
18
,
X
17
X_{19},X_{18}, X_{17}
X19,X18,X17这3步的输入值。
这对应着吴恩达老师讲的那个例子: The cat, which ate already, …, was full。 就是后面的was还是were, 要看前面是cat, 还是cats, 但是一旦中间的这个which 句子很长, cat的信息根本传不到was这里来。对was的更新没有任何帮助, 这是RNN一个很大的不足之处。
所以,通过上面的分析, 我们知道了RNN存在着一个很大的问题梯度消失,而RNN出现梯度消失问题之后, 就没法再捕捉序列之间的长期关联或者依赖关系。
而解决上面这个问题的根本,其实就是 ∏ j = k + 1 t ∂ S j ∂ S j − 1 \prod_{j=k+1}^{t} \frac{\partial S_{j}}{\partial S_{j-1}} ∏j=k+1t∂Sj−1∂Sj, 因为这个连乘, 才会有梯度消失或者爆炸现象,进而才会无法捕捉长期依赖。 那么如何解决这个问题呢? 那就是让这个连乘保持一个常量, 这样的话就不会梯度消失或者爆炸了。 当然RNN是做不到了, 所以LSTM就诞生了。
5. 总结
好了, 这篇基础知识的内容就整理到这里, 如果后面加上LSTM就会太多了, 所以趁热乎快速回顾一下: 这篇文章就是围绕着时序序列的任务进行展开, 从全连接网络开始,复习了一下DNN的步骤和处理这种时序序列任务的局限性, 引出了RNN, 然后重点说了一下RNN的运算原理和几个细节部分, 纠正一下初学者对RNN的理解误差, 然后为了更加详细的理解RNN的计算原理,用numpy实现了一下前向传播的过程, 并有一个例子写了一下反向传播的公式, 并解释了一下为什么RNN会存在梯度消失和爆炸现象, 为什么不能捕捉长期依赖关系, 最后又分析了这两个问题的解决关键在什么地方。
而RNN的这两个问题到底是如何解决的呢? 下一篇重温LSTM及其变体GRU中告诉你 😉
参考: