[rnn]BPTT_梯度消失/爆炸问题

http://www.wildml.com/2015/10/recurrent-neural-networks-tutorial-part-3-backpropagation-through-time-and-vanishing-gradients/

翻译:
https://zhuanlan.zhihu.com/p/22338087

随时间的反向传播(BPTT)

让我们先迅速回忆一下RNN的基本公式,注意到这里在符号上稍稍做了改变(o变成 y^ ),这只是为了和我参考的一些资料保持一致。
st=tanh(Uxt+Wst1)
y^t=softmax(Vst)

同样把损失值定义为交叉熵损失,如下:
Et(yt,y^t)=ytlog(y^t)
E(y,y^)=tEt(yt,y^t)=tytlogy^t
这里, yt 表示时刻t正确的词, y^t 是我们的预测。通常我们会把整个句子作为一个训练样本,所以总体错误是每一时刻的错误的加和。
这里写图片描述
我们的目标是计算错误值相对于参数U, V, W的梯度以及用随机梯度下降学习好的参数。就像我们要把所有错误相加一样,我们同样会把每一时刻针对每个训练样本的梯度值相加: EW=tEtW
为了计算梯度,我们使用链式求导法则,主要是用反向传播算法往后传播错误。下文使用 E3 作为例子,主要是为了描述方便。

E3V=E3y^3y^3V=E3y^3y^3z3z3V=(y^3y3)s3
上面 z3=Vs3 是向量的外积。如果你不理解上面的公式,不要担心,我在这里跳过了一些步骤,你可以自己尝试来计算这些梯度值。这里我想说明的一点是梯度值只依赖于当前时刻的结果 y^3,y3,s3 。根据这些,计算V的梯度就只剩下简单的矩阵乘积了。

但是对于梯度 E3W 情况就不同了,我们可以像上面一样写出链式法则。
E3W=E3y^3y^3s3s3W

注意到这里的 s3=tanh(Uxt+Ws2) 依赖于 s2 s2 依赖于W和 s1 ,等等。所以为了得到W的梯度,我们不能将 s2 看作常量。我们需要再次使用链式法则,得到的结果如下:
E3W=3k=0E3y^3y^3s3s3skskW

我们把每一时刻得到的梯度值加和,换句话说,W在计算输出的每一步中都使用了。我们需要通过将t=3时刻的梯度反向传播至t=0时刻。
这里写图片描述
注意到这里和我们在深度前向神经网络中使用的标准反向传播算法是一致的,关键不同在于我们把每一时刻针对W的不同梯度做了加和。在传统神经网络中,不需要在层之间共享参数,就不需要做任何加和。在我看来,BPTT是应用于展开的RNN上的标准反向传播的另一个名字。就像反向传播一样,你也可以定义一个反向传递的delta向量,例如, δ(3)2=E3z2=E3s3s3s2s2z2 ,其中 z2=Ux2+Ws1

这会让你明白为什么标准RNN很难训练:序列会变得很长,可能有20个词或更多,因而就需要反向传播很多层。实践中,很多人会把发现传播截断至几步。

梯度消失问题

在教程前一部分,我提到RNN很难学到长范围的依赖——相隔几步的词之间的交互。这是有问题的因为英语中句子的意思通常由相距不是很近的词来决定:“The man who wore a wig on his head went inside”。这个句子讲的是一个男人走了进去,而不是关于假发。但是普通的RNN不可能捕捉这样的信息。要理解为什么,让我们先仔细看一下上面计算的梯度:
E3W=3k=0E3y^3y^3s3s3sks3W
注意到 s3sk 也需要使用链式法则,例如, s3s1=s3s2s2s1 。注意到因为我们是用向量函数对向量求导数,结果是一个矩阵(称为Jacobian Matrix),矩阵元素是每个点的导数。我们可以把上面的梯度重写成:
E3W=3k=0E3y^3y^3s3(3j=k+1sjsj1)skW
可以证明上面的Jacobian矩阵的二范数(可以认为是一个绝对值)的上界是1。这很直观,因为激活函数tanh把所有制映射到-1和1之间,导数值得界限也是1:
这里写图片描述
你可以看到tanh和sigmoid函数在两端的梯度值都为0,接近于平行线。当这种情况出现时,我们就认为相应的神经元饱和了。它们的梯度为0使得前面层的梯度也为0。矩阵中存在比较小的值,多个矩阵相乘会使梯度值以指数级速度下降,最终在几步后完全消失。比较远的时刻的梯度值为0,这些时刻的状态对学习过程没有帮助,导致你无法学习到长距离依赖。消失梯度问题不仅出现在RNN中,同样也出现在深度前向神经网中。只是RNN通常比较深(例子中深度和句子长度一致),使得这个问题更加普遍。

很容易想到,依赖于我们的激活函数和网络参数,如果Jacobian矩阵中的值太大,会产生梯度爆炸而不是梯度消失问题。梯度消失比梯度爆炸受到了更多的关注有两方面的原因。其一,梯度爆炸容易发现,梯度值会变成NaN,导致程序崩溃。其二,用预定义的阈值裁剪梯度可以简单有效的解决梯度爆炸问题。梯度消失出现的时候不那么明显而且不好处理。

幸运的是,已经有一些方法解决了梯度消失问题。合适的初始化矩阵W可以减小梯度消失效应,正则化也能起作用。更好的方法是选择ReLU而不是sigmoid和tanh作为激活函数。ReLU的导数是常数值0或1,所以不可能会引起梯度消失。更通用的方案时采用长短项记忆(LSTM)或门限递归单元(GRU)结构。LSTM在1997年第一次提出,可能是目前在NLP上最普遍采用的模型。GRU,2014年第一次提出,是LSTM的简化版本。这两种RNN结构都是为了处理梯度消失问题而设计的,可以有效地学习到长距离依赖。


https://www.zhihu.com/question/34878706
LSTM只能避免RNN的梯度消失(gradient vanishing);梯度膨胀(gradient explosion)不是个严重的问题,一般靠裁剪后的优化算法即可解决,比如gradient clipping(如果梯度的范数大于某个给定值,将梯度同比收缩)。下面简单说说LSTM如何避免梯度消失.

  • 传统的RNN总是用“覆写”的方式计算状态 St=f(St1,xt) ,其中 f() 表示仿射变换外面在套一个Sigmoid, xt 表示输入序列在时刻 t 的值。根据求导的链式法则,这种形式直接导致梯度被表示为连成积的形式,以致于造成梯度消失——粗略的说,很多个小于1的项连乘就很快的逼近零。

  • 现代的RNN(包括但不限于使用LSTM单元的RNN)使用“累加”的形式计算状态:St=tτ=1ΔSτ,其中的 ΔSτ 显示依赖序列输入 xt . 稍加推导即可发现,这种累加形式导致导数也是累加形式,因此避免了梯度消失。

  • 2
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
在线性 RNN 上通过时间反向传播计算梯度的方法称为 BPTT(Backpropagation Through Time)。BPTT 是一种递归算法,用于计算 RNN 模型中的梯度。 在 BPTT 中,我们将 RNN 展开成一个时间步长序列,每个时间步长都是一个相同的网络结构。我们首先将输入序列 $x_1, x_2, ..., x_T$ 通过 RNN 模型得到输出序列 $y_1, y_2, ..., y_T$。然后我们定义损失函数 $L(y_1, y_2, ..., y_T)$,并计算损失函数对每个时间步长的输出的梯度 $\frac{\partial L}{\partial y_t}$。 接下来,我们使用链式法则计算每个时间步长的梯度。对于每个时间步长 $t$,我们需要计算 $\frac{\partial L}{\partial y_t}$,$\frac{\partial y_t}{\partial h_t}$ 和 $\frac{\partial h_t}{\partial h_{t-1}}$,其中 $h_t$ 是时间步长 $t$ 的隐藏状态。 $\frac{\partial L}{\partial y_t}$ 可以通过损失函数的定义直接计算。$\frac{\partial y_t}{\partial h_t}$ 和 $\frac{\partial h_t}{\partial h_{t-1}}$ 则可以通过 RNN 模型的前向传播和反向传播计算得到。然后我们可以使用链式法则将这些梯度相乘,计算出 $\frac{\partial L}{\partial h_{t-1}}$。这个过程可以一直往前传递,直到时间步长 $1$。 最后,我们可以使用这些梯度来更新模型的参数。具体地,我们可以使用随机梯度下降等优化算法来更新参数,以最小化损失函数。 总的来说,BPTT 是一种有效的算法,可以用于训练 RNN 模型。然而,由于 RNN 的时间步长可能很大,BPTT 很容易导致梯度消失梯度爆炸问题。因此,我们需要采取一些技巧来解决这些问题,例如剪枝梯度、使用 LSTM 等。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值