RNN的反向传播及RNN的梯度消失与梯度爆炸

这篇文章讲一下RNN的反向传播算法BPTT,及RNN梯度消失和梯度爆炸的原因。

BPTT

RNN的反向传播,也称为基于时间的反向传播算法BPTT(back propagation through time)。对所有参数求损失函数的偏导,并不断调整这些参数使得损失函数变得尽可能小。

先贴出RNN的结构图以供观赏,下面讲的都是图中的单层单向RNN:

反向传播做的是:算出参数梯度更新参数。参数梯度=损失函数对参数的偏导数\frac{\partial L}{\partial w},参数更新公式已知:w=w-\alpha \frac{\partial L}{\partial w}\alpha是学习率。所以我们只需要求出 \frac{\partial L}{\partial w}

分两部分:1.定义损失函数 L2.再求出损失函数对参数的偏导数 \frac{\partial L}{\partial w},下面开始啦,以参数w为例子。

1.定义损失函数:

假设时刻 tt 的损失函数为: 

损失函数是均方差也好交叉熵也好都无所谓,这里只是举个例子,假定它是均方差。

因为有多个时刻,所以总损失函数为所有时刻的损失函数之和

   (1)

第一步完成,损失函数get!

2.求损失函数对参数的偏导数:

w在每一时刻都出现了,所以w在时刻 t 的梯度=时刻 t 的损失函数对所有时刻w的梯度

  (2)

将(2)代入(1)可得下面的结果, w的总梯度 \frac{\partial L}{\partial w}等于w所有时刻的梯度和:

第二步完成,梯度get!

3.更新参数

有了梯度就可以更新参数了:

以上三步就是针对参数 的一次反向传播,个人认为了解这些就可以了。

梯度消失

下面是RNN前向传播公式,其中f是激活 函数,g 一般是 softmax函数。

假设我们有一个RNN,时间序列只有三个时刻,下面是其结构图:

preview

前向传播:

反向传播:
我们现在只对t=3时刻的 U、V、W求损失函数 L3的偏导(其他时刻类似):

t = 3 时刻加上之前的时刻,一共是3,等于上下两个式子的加数。若 t 足够大,则式中的加数就会很多,红色部分的项数也越多。

根据上述公式,我们可以得出任意时刻 t 时对 U、V、W求偏导得公式,以  W 为例。

激活函数tanh和它的导数图像如下

可以看出 tanh‘  \leq 1 ,训练过程中几乎都是小于1的,而W 的值一般会处于0~1之间,当时间序列足够长,即t足够大时,

就会趋近于0,这就造成了梯度消失;当W值很大(一般为初始化不当引起)时,{

就会趋近于无穷,这就造成了梯度爆炸。

RNN梯度消失问题很常见,题都爆炸问题一般不常见。RNN中的梯度消失会造成什么后果呢?会使RNN的长时记忆失效,简而言之就是会忘记很久之前的信息,记性不好。

至于怎么避免RNN的梯度消失和梯度都爆炸,我们现在知道,造成这种现象的根本原因就在于

这个连乘式,我们可以使这个连乘式中每项的偏导

这就是LSTM做的事情了。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

&刘仔很忙

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值