博客已迁至知乎, 这篇文章的链接:https://zhuanlan.zhihu.com/p/70868991
前言
上篇文章RNN详解已经介绍了RNN的结构和前向传播的计算公式,这篇文章讲一下RNN的反向传播算法BPTT,及RNN梯度消失和梯度爆炸的原因。
BPTT
RNN的反向传播,也称为基于时间的反向传播算法BPTT(back propagation through time)。对所有参数求损失函数的偏导,并不断调整这些参数使得损失函数变得尽可能小。
先贴出RNN的结构图以供观赏,下面讲的都是图中的单层单向RNN:
图片来自:Recurrent Neural Networks Tutorial, Part 1 – Introduction to RNNs
反向传播做的是:算出参数梯度并更新参数。参数梯度=损失函数对参数的偏导数 ∂ L ∂ w \frac{\partial L}{\partial w} ∂w∂L,参数更新公式已知: w = w − α ∂ L ∂ w w=w-\alpha \frac{\partial L}{\partial w} w=w−α∂w∂L, α \alpha α是学习率。所以我们只需要求出 ∂ L ∂ w \frac{\partial L}{\partial w} ∂w∂L就可以了。
分两部分:1.定义损失函数 L L L,2.再求出损失函数对参数的偏导数 ∂ L ∂ w \frac{\partial L}{\partial w} ∂w∂L ,下面开始啦,以参数 w w w 为例:
1.定义损失函数:
假设时刻 t t t 的损失函数为: L t = 1 2 ( Y 3 − O 3 ) 2 L_t=\frac{1}{2}(Y_3-O_3)^2 Lt=21(Y3−O3)2
损失函数是均方差也好交叉熵也好都无所谓,这里只是举个例子,假定它是均方差。
因为有多个时刻,所以总损失函数为所有时刻的损失函数之和:
L = ∑ t = 0 T L t ( 1 ) L = \sum_{t=0}^{T}L_t \ \ \ \ \ \ \ \ \ \ \ \ \ \ (1) L=t=0∑TLt (1)
第一步完成,损失函数get!
2.求损失函数对参数的偏导数:
w w w在每一时刻都出现了,所以 w w w 在时刻 t t t 的梯度=时刻 t t t 的损失函数对所有时刻的 w w w 的梯度和:
∂ L t ∂ w = ∑ s = 0 T ∂ L t ∂ w s ( 2 ) \frac{\partial L_t}{\partial w}=\sum_{s=0}^{T}\frac{\partial L_t}{\partial w_s}\ \ \ \ \ (2) ∂w∂Lt=s=0∑T∂ws∂Lt (2)
将(2)代入(1)可得下面的结果, w w w 的总梯度 ∂ L ∂ w \frac{\partial L}{\partial w} ∂w∂L等于 w w w在所有时刻的梯度和:
∂ L ∂ w = ∑ t = 0 T ∂ L t ∂ w = ∑ t = 0 T ∑ s = 0 T ∂ L t ∂ w s \begin{aligned} \frac{\partial L}{\partial w} &=\sum_{t=0}^{T}\frac{\partial L_t}{\partial w}\\ &=\sum_{t=0}^{T}\sum_{s=0}^{T}\frac{\partial L_t}{\partial w_s}\\ \end{aligned} ∂w∂L=t=0∑T∂w∂Lt=t=0∑Ts=0∑T∂w