其实挺简单:RNN的反向传播

1、一个结构和两个公式

在这里插入图片描述
s t = f ( U ⋅ x t + W ⋅ s t − 1 ) s_t=f(U\cdot x_t+W\cdot s_{t-1}) st=f(Uxt+Wst1)
o t = g ( V ⋅ s t ) o_t=g(V\cdot s_t) ot=g(Vst)

2、简要说明反向传播

    在训练过程中,我们会进行前向传播,得到每一时刻的预测值,在RNN中预测值就是 o t o_t ot,这些预测值 o t o_t ot和真实值 y t y_t yt之间存在误差,利用这一点我们可以想办法构造出损失函数,常见的是交叉熵损失,则对于每一时刻的损失我们设为: L t = f ( o t , y t ) L_t=f(o_t,y_t) Lt=f(ot,yt),这里的f只是一个代号,表示一个函数,希望你能理解这种抽象的表达,这和上面的两个公式道理一样。
    在有了每一时刻的损失函数之后,我们需要对参数进行更新,更新的目的是为了最小化损失函数,一般做法就是梯度下降,即损失函数对每一个参数求偏导,再沿着偏导方向去更新参数。这里会有两个深度学习中常见的问题:梯度消失和梯度爆炸。当我们的损失函数对某个参数求偏导,得到的导数有可能非常非常小,也有可能非常非常大,这是可能的,我们在RNN中会更常见这两种现象,主要是由于RNN的结构问题。梯度消失的表现,当导数非常小的时候,参数更新幅度就特别小,近乎没有,也就是梯度更新了,参数却基本保持不变。参数不变,则预测值不变;预测值不变,则准确率不会改变。梯度爆炸的表现,导数很大,则参数更新幅度很大,甚至会出现来回震荡,不收敛的现象。
    关于这两种现象的一些解决办法呢?对于梯度爆炸,我们可以通过梯度裁剪的方式去考虑,就是给参数进行梯度更新这个操作一个限制,只有在梯度小于某个值的时候,才进行参数更新。对于梯度消失,我们已经知道梯度消失,是因为梯度太小,即导数太小导致参数更新基本保持不变,第一种做法是增大学习率,使参数能够正常更新;第二种做法是减少网络深度,这是因为网络层数很多时,损失函数对开始几层参数的导数通常是很多导数的相乘,一般很容易出现梯度消失现象。
    言归正传,反向传播是训练过程中来更新网络参数的,两个步骤,计算梯度,更新参数。

3、手动推导一下

    只当RNN有三个权重参数,不考虑偏置项,U,V,W,重点说明RNN网络的损失函数对W的偏导,因为这是RNN出现梯度消失或梯度爆炸的根本原因。注意,这里说的RNN网络的损失函数是只网络的整体损失,对应了每一个时刻的损失之和。

(1)考虑第一个时刻的损失 L 1 L_1 L1对W的偏导

∂ L 1 ∂ W = ∂ L 1 ∂ o 1 ⋅ ∂ o 1 ∂ s 1 ⋅ ∂ + s 1 ∂ W \frac{\partial L_1}{\partial W}=\frac{\partial L_1}{\partial o_1} \cdot \frac{\partial o_1}{\partial s_1}\cdot \frac{\partial ^+s_1}{\partial W} WL1=o1L1s1o1W+s1

这里需要说明的有两点,“1、与求导有关的 量,作为变量”;“2、求导结束后,没有变量,视为终止”;从这两点说明一下上面的式子,我们有如上链式求导的原因是, L 1 L_1 L1是由 o 1 o_1 o1来的, o 1 o_1 o1是由 s 1 s_1 s1来的,而 s 1 s_1 s1中包含了求导变量W,所以其上游的 o 1 o_1 o1 s 1 s_1 s1均作为变量。终止条件是在 s 1 s_1 s1对W求偏导之后,不存在与W相关的变量了,所以终止。

(2)考虑第二个时刻的损失 L 2 L_2 L2对W的偏导

∂ L 2 ∂ W = ∂ L 2 ∂ o 2 ⋅ ∂ o 2 ∂ s 2 ⋅ ∂ s 2 ∂ W = ∂ L 2 ∂ o 2 ⋅ ∂ o 2 ∂ s 2 ⋅ ( ∂ s 2 ∂ s 1 ⋅ ∂ + s 1 ∂ W + ∂ + s 2 ∂ W ) \frac{\partial L_2}{\partial W}=\frac{\partial L_2}{\partial o_2} \cdot \frac{\partial o_2}{\partial s_2}\cdot \frac{\partial s_2}{\partial W}=\frac{\partial L_2}{\partial o_2} \cdot \frac{\partial o_2}{\partial s_2}\cdot (\frac{\partial s_2}{\partial s_1}\cdot \frac{\partial ^+s_1}{\partial W}+\frac{\partial ^+s_2}{\partial W}) WL2=o2L2s2o2Ws2=o2L2s2o2s1s2W+s1W+s2

需要注意的是 s 2 = f ( U ⋅ x t + W ⋅ s 1 ) s_2=f(U\cdot x_t+W\cdot s_{1}) s2=f(Uxt+Ws1),所以在第二步等于的时候,不要觉得奇怪。类似的例子 g ( x ) = x g(x)=x g(x)=x y = x ⋅ g ( x ) y=x\cdot g(x) y=xg(x),那么 ∂ y ∂ x = ∂ y ∂ x + ∂ y ∂ g ⋅ ∂ g ∂ x \frac{\partial y}{\partial x}=\frac{\partial y}{\partial x} +\frac{\partial y}{\partial g}\cdot \frac{\partial g}{\partial x} xy=xy+gyxg。 这里的写法上有问题,但过程是这么回事。也可以这么考虑,等于号之前 ∂ s 2 ∂ W \frac{\partial s_2}{\partial W} Ws2是延续的求导,等于号之后 ∂ + s 2 ∂ W \frac{\partial ^+s_2}{\partial W} W+s2只考虑当前的时间戳;

(3)考虑第三个时刻的损失 L 3 L_3 L3对W的偏导

∂ L 3 ∂ W = ∂ L 3 ∂ o 3 ⋅ ∂ o 3 ∂ s 3 ⋅ ∂ s 3 ∂ W = ∂ L 3 ∂ o 3 ⋅ ∂ o 3 ∂ s 3 ⋅ ( ∂ s 3 ∂ s 2 ⋅ ∂ s 2 ∂ W + ∂ + s 3 ∂ W ) = ∂ L 3 ∂ o 3 ⋅ ∂ o 3 ∂ s 3 ⋅ ( ∂ s 3 ∂ s 2 ⋅ ∂ s 2 ∂ s 1 ⋅ ∂ + s 1 ∂ W + ∂ s 3 ∂ s 2 ⋅ ∂ + s 2 ∂ W + ∂ + s 3 ∂ W ) \frac{\partial L_3}{\partial W}=\frac{\partial L_3}{\partial o_3} \cdot \frac{\partial o_3}{\partial s_3}\cdot \frac{\partial s_3}{\partial W}=\frac{\partial L_3}{\partial o_3} \cdot \frac{\partial o_3}{\partial s_3}\cdot (\frac{\partial s_3}{\partial s_2}\cdot \frac{\partial s_2}{\partial W}+\frac{\partial ^+s_3}{\partial W})=\frac{\partial L_3}{\partial o_3} \cdot \frac{\partial o_3}{\partial s_3}\cdot (\frac{\partial s_3}{\partial s_2}\cdot \frac{\partial s_2}{\partial s_1}\cdot \frac{\partial ^+s_1}{\partial W}+\frac{\partial s_3}{\partial s_2}\cdot \frac{\partial ^+s_2}{\partial W}+ \frac{\partial ^+s_3}{\partial W}) WL3=o3L3s3o3Ws3=o3L3s3o3s2s3Ws2W+s3=o3L3s3o3s2s3s1s2W+s1s2s3W+s2W+s3

(4)考虑 t 时刻的损失 L t L_t Lt对W的偏导,从上面总结一下规律,有

∂ L t ∂ W = ∂ L t ∂ o t ⋅ ∂ o t ∂ s t ⋅ ∂ s t ∂ W = ∂ L t ∂ o t ⋅ ∂ o t ∂ s t ⋅ ∑ i = 1 t ∂ s t ∂ s i ⋅ ∂ + s i ∂ W \frac{\partial L_t}{\partial W}=\frac{\partial L_t}{\partial o_t} \cdot \frac{\partial o_t}{\partial s_t}\cdot \frac{\partial s_t}{\partial W}=\frac{\partial L_t}{\partial o_t} \cdot \frac{\partial o_t}{\partial s_t}\cdot \sum_{i=1}^t\frac{\partial s_t}{\partial s_i} \cdot \frac{\partial ^+s_i}{\partial W} WLt=otLtstotWst=otLtstoti=1tsistW+si

可以看到写成这样简洁多了,其中有一个 ∂ s t ∂ s i \frac{\partial s_t}{\partial s_i} sist是造成梯度消失或梯度爆炸的罪魁祸首,因为这个式子展开之后:

∂ s t ∂ s i = ∏ k = i t − 1 ∂ s k + 1 ∂ s k \frac{\partial s_t}{\partial s_i} = \prod_{k=i}^{t-1}\frac{\partial s_{k+1}}{\partial s_k} sist=k=it1sksk+1

这是多个偏导连乘的运算,就可能出现梯度消失或梯度爆炸的情况。

(5)注意上述的公式是某一时刻损失对W的偏导,不是总体损失对W的偏导,理解到此应该差不多了。

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值