这篇文章讲一下RNN的反向传播算法BPTT,及RNN梯度消失和梯度爆炸的原因。
BPTT
RNN的反向传播,也称为基于时间的反向传播算法BPTT(back propagation through time)。对所有参数求损失函数的偏导,并不断调整这些参数使得损失函数变得尽可能小。
先贴出RNN的结构图以供观赏,下面讲的都是图中的单层单向RNN:
反向传播做的是:算出参数梯度并更新参数。参数梯度=损失函数对参数的偏导数,参数更新公式已知: 。是学习率。所以我们只需要求出 。
分两部分:1.定义损失函数 ,2.再求出损失函数对参数的偏导数 ,下面开始啦,以参数为例子。
1.定义损失函数:
假设时刻 tt 的损失函数为:
损失函数是均方差也好交叉熵也好都无所谓,这里只是举个例子,假定它是均方差。
因为有多个时刻,所以总损失函数为所有时刻的损失函数之和:
(1)
第一步完成,损失函数get!
2.求损失函数对参数的偏导数:
在每一时刻都出现了,所以在时刻 t 的梯度=时刻 t 的损失函数对所有时刻的的梯度和:
(2)
将(2)代入(1)可得下面的结果, w的总梯度 等于在所有时刻的梯度和:
第二步完成,梯度get!
3.更新参数
有了梯度就可以更新参数了:
以上三步就是针对参数 w 的一次反向传播,个人认为了解这些就可以了。
梯度消失
下面是RNN前向传播公式,其中是激活 函数, 一般是 softmax函数。
假设我们有一个RNN,时间序列只有三个时刻,下面是其结构图:
前向传播:
反向传播:
我们现在只对t=3时刻的 U、V、W求损失函数 L3的偏导(其他时刻类似):
t = 3 时刻加上之前的时刻,一共是3,等于上下两个式子的加数。若 t 足够大,则式中的加数就会很多,红色部分的项数也越多。
根据上述公式,我们可以得出任意时刻 t 时对 U、V、W求偏导得公式,以 W 为例。
激活函数tanh和它的导数图像如下
可以看出 tanh‘ 1 ,训练过程中几乎都是小于1的,而W 的值一般会处于0~1之间,当时间序列足够长,即t足够大时,
就会趋近于0,这就造成了梯度消失;当W值很大(一般为初始化不当引起)时,{
就会趋近于无穷,这就造成了梯度爆炸。
RNN梯度消失问题很常见,题都爆炸问题一般不常见。RNN中的梯度消失会造成什么后果呢?会使RNN的长时记忆失效,简而言之就是会忘记很久之前的信息,记性不好。
至于怎么避免RNN的梯度消失和梯度都爆炸,我们现在知道,造成这种现象的根本原因就在于
这个连乘式,我们可以使这个连乘式中每项的偏导
或
这就是LSTM做的事情了。