RNN基本概述
我们首先看一下百度百科对于RNN的解释:
循环神经网络(Recurrent Neural Network, RNN)是一类以序列(sequence)数据为输入,在序列的演进方向进行递归(recursion)且所有节点(循环单元)按链式连接的递归神经网络(recursive neural network)。
对循环神经网络的研究始于二十世纪80-90年代,并在二十一世纪初发展为深度学习(deep learning)算法之一,其中双向循环神经网络(Bidirectional RNN, Bi-RNN)和长短期记忆网络(Long Short-Term Memory networks,LSTM)是常见的的循环神经网络。
循环神经网络具有记忆性、参数共享并且图灵完备(Turing completeness),因此在对序列的非线性特征进行学习时具有一定优势。循环神经网络在自然语言处理(Natural Language Processing, NLP),例如语音识别、语言建模、机器翻译等领域有应用,也被用于各类时间序列预报。引入了卷积神经网络(Convoutional Neural Network,CNN)构筑的循环神经网络可以处理包含序列输入的计算机视觉问题。
RNN的优势及结构形式
我们用CNN的网络结果进行比较:
观察CNN的结构可以发现,神经网络通常包括输入层、隐藏层、输出层。正向传播的过程为:1输入数据、2设置权重连接层与层、3通过激活函数控制隐含层的输出(对于多层结构,这里的输出相当于下一层的输入)、4比较输出层结果与真实标签的差值,通过反向传播算法更新权重值。在这个过程中,激活函数是我们事先设定好的,权重值是我们学习过程中不断迭代的,也就是我们要学习的目标。所谓神经学习就是学习这些权重值。观察整个流程还可以发现一个很重要的一点,输入的数据调整顺序并不会影响结果(调换
x
1
x_{1}
x1 和
x
2
x_{2}
x2的顺序对结果没有影响)。
调换输入的参数顺序对结果不产生影响这个特点很多时候是不可行的,比如我们做机器翻译问题、语义识别问题等,这些都是要结合上下文决定我们的预测结果。因此我们的网络结构要时刻关注之前甚至之后的输入信息。
RNN很好的解决了对前面输入信息的一个兼容问题,如下图所示:
上图左边是RNN的基本结构,右边是它的展开形式。
x
x
x是某一个时刻的输入,
s
s
s是对应时刻的记忆(相当于隐藏层),它捕捉了之前时间点上的信息,
o
o
o是该时刻对应的输出,它由当前时刻及之前所有的‘记忆’共同计算得到。
这里的**
U
、
V
、
W
U、V、W
U、V、W**在整个神经网络上都共享一组参数(整个神经网络共用同一个
U
U
U,同一个
V
V
V和同一个
W
W
W),这样做极大的减少了需要训练和预估的参数量。
每一个输入值都只与它本身的那条路线建立权连接,不会和别的神经元连接。
RNN的前向传播过程
现在我们知道了RNN的基本结构,我们看一下关于RNN的前向传播算法,我们以 t t t时刻为例:
o
t
=
σ
(
V
s
t
)
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
1
o^{t}=\sigma (Vs^{t})..............................................................................................................1
ot=σ(Vst)..............................................................................................................1
s
t
=
ϕ
(
U
x
t
+
W
s
t
−
1
+
b
)
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
2
s^{t}=\phi (Ux^{t}+Ws^{t-1}+b)..........................................................................................2
st=ϕ(Uxt+Wst−1+b)..........................................................................................2
上面两个公式中
σ
,
ϕ
\sigma,\phi
σ,ϕ都是激活函数,我们把2带入1得到
o
t
=
σ
V
(
ϕ
(
U
x
t
+
W
s
t
−
1
+
b
)
)
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
3
o^{t}=\sigma V(\phi (Ux^{t}+Ws^{t-1}+b))..................................................................................3
ot=σV(ϕ(Uxt+Wst−1+b))..................................................................................3
o
t
=
σ
V
(
ϕ
(
U
x
t
+
W
(
ϕ
(
U
x
t
−
1
+
W
s
t
−
2
+
b
)
)
+
b
)
)
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
4
o^{t}=\sigma V(\phi (Ux^{t}+W(\phi (Ux^{t-1}+Ws^{t-2}+b))+b))...................................................4
ot=σV(ϕ(Uxt+W(ϕ(Uxt−1+Wst−2+b))+b))...................................................4
o
t
=
σ
V
(
ϕ
(
U
x
t
+
W
(
ϕ
(
U
x
t
−
1
+
W
(
ϕ
(
U
x
t
−
2
+
W
s
t
−
3
+
b
)
)
+
b
)
)
+
b
)
)
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
5
o^{t}=\sigma V(\phi (Ux^{t}+W(\phi (Ux^{t-1}+W(\phi (Ux^{t-2}+Ws^{t-3}+b))+b))+b))....................5
ot=σV(ϕ(Uxt+W(ϕ(Uxt−1+W(ϕ(Uxt−2+Wst−3+b))+b))+b))....................5
.
.
.
.
.
.
.
.
.
.
.
...........
...........
时间步长越大,
s
t
s^{t}
st所要保存的信息越多。
BPTT算法
了解了RNN的正向传播过程,我们现在研究一下他的训练方法。毕竟训练一个网络才是我们的目的。RNN常用的训练方法是BPTT算法(back-propagation through time),通过时间序列的反向传播,其实本质还是BP算法,只不过RNN处理时间序列数据,所以要基于时间反向传播,故叫随时间反向传播。BPTT的中心思想和BP算法相同,沿着需要优化的参数的负梯度方向不断寻找更优的点直至收敛。综上所述,BPTT算法本质还是BP算法,BP算法本质还是梯度下降法,那么求各个参数的梯度便成了此算法的核心。
还是以上述过程为例,我们需要优化的参数为
U
,
V
,
W
U,V,W
U,V,W。
根据上面的前向传播过程,我们可以发现,这三个参数中
V
V
V仅需关注当前,相较于
U
,
W
U,W
U,W,我们先来求
V
V
V的偏导数。
∂ L t ∂ V = ∂ L t ∂ o t ∗ ∂ o t ∂ V \frac{\partial L^{t}}{\partial V}=\frac{\partial L^{t}}{\partial o^{t}}*\frac{\partial o^{t}}{\partial V} ∂V∂Lt=∂ot∂Lt∗∂V∂ot
上述公式仅求得了 t t t时刻的偏导数,由于RNN的损失也是随着时间累加,因此更普适的公式如下:
L = ∑ t = 1 n L t L=\sum_{t=1}^{n}L^{t} L=∑t=1nLt
∂ L ∂ V = ∑ t = 1 n ∂ L t ∂ o t ∗ ∂ o t ∂ V \frac{\partial L}{\partial V}=\sum_{t=1}^{n}\frac{\partial L^{t}}{\partial o^{t}}*\frac{\partial o^{t}}{\partial V} ∂V∂L=∑t=1n∂ot∂Lt∗∂V∂ot
W W W和 U U U的偏导数涉及到历史数据(如上式,我们在不断的分解 s s s),偏导数要不断的嵌入链式法则,我们先假设时刻为3,那么在第三个时刻 L L L对 W W W的偏导数为:
∂ L 3 ∂ W = ∂ L 3 ∂ o 3 ∗ ∂ o 3 ∂ s 3 ∗ ∂ s 3 ∂ W + ∂ L 3 ∂ o 3 ∗ ∂ o 3 ∂ s 3 ∗ ∂ s 3 ∂ s 2 ∗ ∂ s 2 ∂ W + ∂ L 3 ∂ o 3 ∗ ∂ o 3 ∂ s 3 ∗ ∂ s 3 ∂ s 2 ∗ ∂ s 2 ∂ s 1 ∗ ∂ s 1 ∂ W \frac{\partial L^{3}}{\partial W}=\frac{\partial L^{3}}{\partial o^{3}}*\frac{\partial o^{3}}{\partial s^{3}}*\frac{\partial s^{3}}{\partial W}+\frac{\partial L^{3}}{\partial o^{3}}*\frac{\partial o^{3}}{\partial s^{3}}*\frac{\partial s^{3}}{\partial s^{2}}*\frac{\partial s^{2}}{\partial W}+\frac{\partial L^{3}}{\partial o^{3}}*\frac{\partial o^{3}}{\partial s^{3}}*\frac{\partial s^{3}}{\partial s^{2}}*\frac{\partial s^{2}}{\partial s^{1}}*\frac{\partial s^{1}}{\partial W} ∂W∂L3=∂o3∂L3∗∂s3∂o3∗∂W∂s3+∂o3∂L3∗∂s3∂o3∗∂s2∂s3∗∂W∂s2+∂o3∂L3∗∂s3∂o3∗∂s2∂s3∗∂s1∂s2∗∂W∂s1
相应的,该时刻 L L L对 U U U的偏导数为
∂ L 3 ∂ U = ∂ L 3 ∂ o 3 ∗ ∂ o 3 ∂ s 3 ∗ ∂ s 3 ∂ U + ∂ L 3 ∂ o 3 ∗ ∂ o 3 ∂ s 3 ∗ ∂ s 3 ∂ s 2 ∗ ∂ s 2 ∂ U + ∂ L 3 ∂ o 3 ∗ ∂ o 3 ∂ s 3 ∗ ∂ s 3 ∂ s 2 ∗ ∂ h 2 ∂ s 1 ∗ ∂ s 1 ∂ U \frac{\partial L^{3}}{\partial U}=\frac{\partial L^{3}}{\partial o^{3}}*\frac{\partial o^{3}}{\partial s^{3}}*\frac{\partial s^{3}}{\partial U}+\frac{\partial L^{3}}{\partial o^{3}}*\frac{\partial o^{3}}{\partial s^{3}}*\frac{\partial s^{3}}{\partial s^{2}}*\frac{\partial s^{2}}{\partial U}+\frac{\partial L^{3}}{\partial o^{3}}*\frac{\partial o^{3}}{\partial s^{3}}*\frac{\partial s^{3}}{\partial s^{2}}*\frac{\partial h^{2}}{\partial s^{1}}*\frac{\partial s^{1}}{\partial U} ∂U∂L3=∂o3∂L3∗∂s3∂o3∗∂U∂s3+∂o3∂L3∗∂s3∂o3∗∂s2∂s3∗∂U∂s2+∂o3∂L3∗∂s3∂o3∗∂s2∂s3∗∂s1∂h2∗∂U∂s1
我们把 s s s之前的相同部分做合并,简化上述公式得到
∂ L t ∂ W = ∑ k = 0 t ∂ L t ∂ o t ∗ ∂ o t ∂ s t ( ∏ j = k + 1 t ∂ s j ∂ s j − 1 ) ∗ ∂ s k ∂ W \frac{\partial L^{t}}{\partial W}=\sum_{k=0}^{t}\frac{\partial L^{t}}{\partial o^{t}}*\frac{\partial o^{t}}{\partial s^{t}}(\prod_{j=k+1}^{t}\frac{\partial s^{j}}{\partial s^{j-1}})*\frac{\partial s^{k}}{\partial W} ∂W∂Lt=∑k=0t∂ot∂Lt∗∂st∂ot(∏j=k+1t∂sj−1∂sj)∗∂W∂sk
∂ L t ∂ U = ∑ k = 0 t ∂ L t ∂ o t ∗ ∂ o t ∂ s t ( ∏ j = k + 1 t ∂ s j ∂ s j − 1 ) ∗ ∂ s k ∂ U \frac{\partial L^{t}}{\partial U}=\sum_{k=0}^{t}\frac{\partial L^{t}}{\partial o^{t}}*\frac{\partial o^{t}}{\partial s^{t}}(\prod_{j=k+1}^{t}\frac{\partial s^{j}}{\partial s^{j-1}})*\frac{\partial s^{k}}{\partial U} ∂U∂Lt=∑k=0t∂ot∂Lt∗∂st∂ot(∏j=k+1t∂sj−1∂sj)∗∂U∂sk
现在我们将激活函数带入上述公式中间的累乘部分:
∏ j = k + 1 t ∂ s j ∂ s j − 1 = ∏ j = k + 1 t t a n h ′ ∗ W s \prod_{j=k+1}^{t}\frac{\partial s^{j}}{\partial s^{j-1}}=\prod_{j=k+1}^{t}tanh^{'}*W_{s} ∏j=k+1t∂sj−1∂sj=∏j=k+1ttanh′∗Ws
或者
∏ j = k + 1 t ∂ s j ∂ s j − 1 = ∏ j = k + 1 t s i g m o i d ′ ∗ W s \prod_{j=k+1}^{t}\frac{\partial s^{j}}{\partial s^{j-1}}=\prod_{j=k+1}^{t}sigmoid^{'}*W_{s} ∏j=k+1t∂sj−1∂sj=∏j=k+1tsigmoid′∗Ws
我们会发现累乘会导致激活函数导数的累乘,进而会导致“梯度消失“和“梯度爆炸“现象的发生。
(激活函数导数范围0-1之间,不断累乘,最终趋于0)
LSTM
上面提到的梯度消失或者梯度爆炸问题会导致一个最直接的结果就是,随着训练的进行,前期的参数对后续的决策影响会越来越小直至没有影响,相当于网络忘记了最初的数据。这对于需要根据上下文去预测下一步会发生什么的网络结构来说几乎是致命的。因此出现了改进型的RNN结构,LSTM和GRU。
我们先来看下面一段例子(例子来自博客RNN,后续的一些理论解释来自百度文库资料)
有时候,我们仅仅需要知道先前的信息来执行当前的任务。例如,我们有一个语言模型用来基于先前的词来预测下一个词。如果我们试着预测 “the clouds are in the sky” 最后的词,我们并不需要任何其他的上下文 —— 因此下一个词很显然就应该是 sky。在这样的场景中,相关的信息和预测的词位置之间的间隔是非常小的,RNN 可以学会使用先前的信息。
但是同样会有一些更加复杂的场景。假设我们试着去预测“I grew up in France… I speak fluent French”最后的词。当前的信息建议下一个词可能是一种语言的名字,但是如果我们需要弄清楚是什么语言,我们是需要先前提到的离当前位置很远的 France 的上下文的。这说明相关信息和当前预测位置之间的间隔就肯定变得相当的大。
不幸的是,在这个间隔不断增大时,RNN 会丧失学习到连接如此远的信息的能力。
为何RNN做不到的事情,LSTM可以做到?我们先来比较一下这两个网络结构的不同,下面我们先看一下RNN的结构形式:
这里我们可以清楚的看到,重复模块里面只有一个简单的结构来连接上下层之间的数据传递,例如
t
a
n
h
tanh
tanh激活函数,通过上面的公式推导我们知道,随着网络的加深,由于梯度消失的原因,较远的信息网络会遗忘。
下面我们来看一下LSTM的网络结构。
我们可以看到,LSTM的结构与RNN相似却不同。除了h随着时间流动之外,细胞状态C也在随时间流动,而细胞状态的流动有点类似于生产线上的传送带流动一样,是直接在整个链上运行,信息在上面保持不变很容易。
关于LSTM的几个名词解释:
细胞状态:细胞状态类似于传送带。直接在整个链上运行,只有一些少量的线性交互。信息在上面流传保持不变。
控制细胞状态方法:
1通过‘门’让信息选择性通过,来去除或者增加信息到细胞状态
2包含一个sigmoid神经网络层和一个pointwise乘法操作
3Sigmoid层输出0到1之间的概率值,描述每个部分有多少量可以通过。0代表‘不允许任何量通过’,1代表‘允许任意量通过’
详细理解LSTM的计算过程
在我们 LSTM 中的第一步是决定我们会从细胞状态中丢弃什么信息。这个决定通过一个称为遗忘门完成。该门会读取
h
t
−
1
h_{t-1}
ht−1和
x
t
x_t
xt,输出一个在 0 到 1 之间的数值给每个在细胞状态
C
t
−
1
C_{t-1}
Ct−1中的数字。1 表示“完全保留”,0 表示“完全舍弃”。
让我们回到语言模型的例子中来基于已经看到的预测下一个词。在这个问题中,细胞状态可能包含当前主语的性别,因此正确的代词可以被选择出来。当我们看到新的主语,我们希望忘记旧的主语。
这里激活函数的一个作用就是决定这个‘门’遗忘的信息量(输出的值为0-1,因此决定了记住百分之多少的内容)
接受新的参数是因为新的输入对结果肯定产生影响。
下一步是确定什么样的新信息被存放在细胞状态中。这里包含两个部分。第一,sigmoid 层称 “输入门层” 决定什么值我们将要更新。然后,一个 tanh 层创建一个新的候选值向量, C ~ t \tilde{C}_{t} C~t会被加入到状态中。
在我们语言模型的例子中,我们希望增加新的主语的性别到细胞状态中,来替代旧的需要忘记的主语。
这一步的目的是更新细胞状态,将 C t − 1 C_{t-1} Ct−1更新为 C t C_t Ct。做法如下:
我们把旧状态与
f
t
f_t
ft 相乘,丢弃掉我们确定需要丢弃的信息。接着加上
i
t
∗
C
~
t
i_t * \tilde{C}_t
it∗C~t .这就是新的候选值,根据我们决定更新每个状态的程度进行变化。
有了上面的理解基础输入门,输出门理解起来就简单多了。sigmoid函数选择更新内容,tanh函数创建更新候选。
最终,我们需要确定输出什么值。这个输出将会基于我们的细胞状态,但是也是一个过滤后的版本。首先,我们运行一个 sigmoid 层来确定细胞状态的哪个部分将输出出去。接着,我们把细胞状态通过 tanh 进行处理(得到一个在 -1 到 1 之间的值)并将它和 sigmoid 门的输出相乘,最终我们仅仅会输出我们确定输出的那部分。
这三个门虽然功能上不同,但在执行任务的操作上是相同的。他们都是使用sigmoid函数作为选择工具,tanh函数作为变换工具,这两个函数结合起来实现三个门的功能。
LSTM三个门的总结:
1,细胞状态是核心,起到了承载‘记忆’的作用
2,前两个门是对细胞状态在传送带上的状态更改,遗忘一部分内容,增加一部分内容。并且这两个门得到了细胞状态的输出。
3,最终的输出是建立在细胞状态这一轮输出的基础上的。
GRU:LSTM的变体
GRU是2014年提出的一种LSTM改进算法. 它将忘记门和输入门合并成为一个单一的更新门, 同时合并了数据单元状态和隐藏状态, 使得模型结构比之于LSTM更为简单.
变体1
我们让门层也接受细胞状态的输入:
变体2
通过使用 coupled 忘记和输入门。不同于之前是分开确定什么忘记和需要添加什么新的信息,这里是一同做出决定。我们仅仅会当我们将要输入在当前位置时忘记。我们仅仅输入新的值到那些我们已经忘记旧的信息的那些状态 。
变体3
将忘记门和输入门合成了一个单一的 更新门。同样还混合了细胞状态和隐藏状态,和其他一些改动。最终的模型比标准的 LSTM 模型要简单,也是非常流行的变体。
双向RNN
有些情况下,当前的输出不只依赖于之前的序列元素,还可能依赖之后的序列元素; 比如做完形填空,机器翻译等应用。这时候我们需要网络结构可以兼顾前后输入信息: