理解BPTT及RNN的梯度消失与梯度爆炸

本文深入探讨了RNN的反向传播算法BPTT,解释了如何定义损失函数、求损失函数对参数的偏导数以及如何更新参数。同时,文章阐述了RNN中导致梯度消失和梯度爆炸的原因,指出长时记忆失效的问题,并提及LSTM是如何解决这一问题的。
摘要由CSDN通过智能技术生成

博客已迁至知乎, 这篇文章的链接:https://zhuanlan.zhihu.com/p/70868991

前言

上篇文章RNN详解已经介绍了RNN的结构和前向传播的计算公式,这篇文章讲一下RNN的反向传播算法BPTT,及RNN梯度消失和梯度爆炸的原因。


BPTT

RNN的反向传播,也称为基于时间的反向传播算法BPTT(back propagation through time)。对所有参数求损失函数的偏导,并不断调整这些参数使得损失函数变得尽可能小。

先贴出RNN的结构图以供观赏,下面讲的都是图中的单层单向RNN:

图片来自:Recurrent Neural Networks Tutorial, Part 1 – Introduction to RNNs

反向传播做的是:算出参数梯度更新参数。参数梯度=损失函数对参数的偏导数 ∂ L ∂ w \frac{\partial L}{\partial w} wL,参数更新公式已知: w = w − α ∂ L ∂ w w=w-\alpha \frac{\partial L}{\partial w} w=wαwL α \alpha α是学习率。所以我们只需要求出 ∂ L ∂ w \frac{\partial L}{\partial w} wL就可以了。

分两部分:1.定义损失函数 L L L2.再求出损失函数对参数的偏导数 ∂ L ∂ w \frac{\partial L}{\partial w} wL ,下面开始啦,以参数 w w w 为例:

1.定义损失函数:

假设时刻 t t t 的损失函数为: L t = 1 2 ( Y 3 − O 3 ) 2 L_t=\frac{1}{2}(Y_3-O_3)^2 Lt=21(Y3O3)2

损失函数是均方差也好交叉熵也好都无所谓,这里只是举个例子,假定它是均方差。

因为有多个时刻,所以总损失函数为所有时刻的损失函数之和
L = ∑ t = 0 T L t                ( 1 ) L = \sum_{t=0}^{T}L_t \ \ \ \ \ \ \ \ \ \ \ \ \ \ (1) L=t=0TLt              (1)

第一步完成,损失函数get!

2.求损失函数对参数的偏导数:

w w w在每一时刻都出现了,所以 w w w 在时刻 t t t 的梯度=时刻 t t t 的损失函数对所有时刻 w w w 的梯度
∂ L t ∂ w = ∑ s = 0 T ∂ L t ∂ w s       ( 2 ) \frac{\partial L_t}{\partial w}=\sum_{s=0}^{T}\frac{\partial L_t}{\partial w_s}\ \ \ \ \ (2) wLt=s=0TwsLt     (2)

将(2)代入(1)可得下面的结果, w w w 的总梯度 ∂ L ∂ w \frac{\partial L}{\partial w} wL等于 w w w所有时刻的梯度和:
∂ L ∂ w = ∑ t = 0 T ∂ L t ∂ w = ∑ t = 0 T ∑ s = 0 T ∂ L t ∂ w s \begin{aligned} \frac{\partial L}{\partial w} &=\sum_{t=0}^{T}\frac{\partial L_t}{\partial w}\\ &=\sum_{t=0}^{T}\sum_{s=0}^{T}\frac{\partial L_t}{\partial w_s}\\ \end{aligned} wL=t=0TwLt=t=0Ts=0Tw

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值