1、一个结构和两个公式
s
t
=
f
(
U
⋅
x
t
+
W
⋅
s
t
−
1
)
s_t=f(U\cdot x_t+W\cdot s_{t-1})
st=f(U⋅xt+W⋅st−1)
o
t
=
g
(
V
⋅
s
t
)
o_t=g(V\cdot s_t)
ot=g(V⋅st)
2、简要说明反向传播
在训练过程中,我们会进行前向传播,得到每一时刻的预测值,在RNN中预测值就是
o
t
o_t
ot,这些预测值
o
t
o_t
ot和真实值
y
t
y_t
yt之间存在误差,利用这一点我们可以想办法构造出损失函数,常见的是交叉熵损失,则对于每一时刻的损失我们设为:
L
t
=
f
(
o
t
,
y
t
)
L_t=f(o_t,y_t)
Lt=f(ot,yt),这里的f只是一个代号,表示一个函数,希望你能理解这种抽象的表达,这和上面的两个公式道理一样。
在有了每一时刻的损失函数之后,我们需要对参数进行更新,更新的目的是为了最小化损失函数,一般做法就是梯度下降,即损失函数对每一个参数求偏导,再沿着偏导方向去更新参数。这里会有两个深度学习中常见的问题:梯度消失和梯度爆炸。当我们的损失函数对某个参数求偏导,得到的导数有可能非常非常小,也有可能非常非常大,这是可能的,我们在RNN中会更常见这两种现象,主要是由于RNN的结构问题。梯度消失的表现,当导数非常小的时候,参数更新幅度就特别小,近乎没有,也就是梯度更新了,参数却基本保持不变。参数不变,则预测值不变;预测值不变,则准确率不会改变。梯度爆炸的表现,导数很大,则参数更新幅度很大,甚至会出现来回震荡,不收敛的现象。
关于这两种现象的一些解决办法呢?对于梯度爆炸,我们可以通过梯度裁剪的方式去考虑,就是给参数进行梯度更新这个操作一个限制,只有在梯度小于某个值的时候,才进行参数更新。对于梯度消失,我们已经知道梯度消失,是因为梯度太小,即导数太小导致参数更新基本保持不变,第一种做法是增大学习率,使参数能够正常更新;第二种做法是减少网络深度,这是因为网络层数很多时,损失函数对开始几层参数的导数通常是很多导数的相乘,一般很容易出现梯度消失现象。
言归正传,反向传播是训练过程中来更新网络参数的,两个步骤,计算梯度,更新参数。
3、手动推导一下
只当RNN有三个权重参数,不考虑偏置项,U,V,W,重点说明RNN网络的损失函数对W的偏导,因为这是RNN出现梯度消失或梯度爆炸的根本原因。注意,这里说的RNN网络的损失函数是只网络的整体损失,对应了每一个时刻的损失之和。
(1)考虑第一个时刻的损失 L 1 L_1 L1对W的偏导
∂ L 1 ∂ W = ∂ L 1 ∂ o 1 ⋅ ∂ o 1 ∂ s 1 ⋅ ∂ + s 1 ∂ W \frac{\partial L_1}{\partial W}=\frac{\partial L_1}{\partial o_1} \cdot \frac{\partial o_1}{\partial s_1}\cdot \frac{\partial ^+s_1}{\partial W} ∂W∂L1=∂o1∂L1⋅∂s1∂o1⋅∂W∂+s1
这里需要说明的有两点,“1、与求导有关的 量,作为变量”;“2、求导结束后,没有变量,视为终止”;从这两点说明一下上面的式子,我们有如上链式求导的原因是, L 1 L_1 L1是由 o 1 o_1 o1来的, o 1 o_1 o1是由 s 1 s_1 s1来的,而 s 1 s_1 s1中包含了求导变量W,所以其上游的 o 1 o_1 o1和 s 1 s_1 s1均作为变量。终止条件是在 s 1 s_1 s1对W求偏导之后,不存在与W相关的变量了,所以终止。
(2)考虑第二个时刻的损失 L 2 L_2 L2对W的偏导
∂ L 2 ∂ W = ∂ L 2 ∂ o 2 ⋅ ∂ o 2 ∂ s 2 ⋅ ∂ s 2 ∂ W = ∂ L 2 ∂ o 2 ⋅ ∂ o 2 ∂ s 2 ⋅ ( ∂ s 2 ∂ s 1 ⋅ ∂ + s 1 ∂ W + ∂ + s 2 ∂ W ) \frac{\partial L_2}{\partial W}=\frac{\partial L_2}{\partial o_2} \cdot \frac{\partial o_2}{\partial s_2}\cdot \frac{\partial s_2}{\partial W}=\frac{\partial L_2}{\partial o_2} \cdot \frac{\partial o_2}{\partial s_2}\cdot (\frac{\partial s_2}{\partial s_1}\cdot \frac{\partial ^+s_1}{\partial W}+\frac{\partial ^+s_2}{\partial W}) ∂W∂L2=∂o2∂L2⋅∂s2∂o2⋅∂W∂s2=∂o2∂L2⋅∂s2∂o2⋅(∂s1∂s2⋅∂W∂+s1+∂W∂+s2)
需要注意的是 s 2 = f ( U ⋅ x t + W ⋅ s 1 ) s_2=f(U\cdot x_t+W\cdot s_{1}) s2=f(U⋅xt+W⋅s1),所以在第二步等于的时候,不要觉得奇怪。类似的例子 g ( x ) = x g(x)=x g(x)=x, y = x ⋅ g ( x ) y=x\cdot g(x) y=x⋅g(x),那么 ∂ y ∂ x = ∂ y ∂ x + ∂ y ∂ g ⋅ ∂ g ∂ x \frac{\partial y}{\partial x}=\frac{\partial y}{\partial x} +\frac{\partial y}{\partial g}\cdot \frac{\partial g}{\partial x} ∂x∂y=∂x∂y+∂g∂y⋅∂x∂g。 这里的写法上有问题,但过程是这么回事。也可以这么考虑,等于号之前 ∂ s 2 ∂ W \frac{\partial s_2}{\partial W} ∂W∂s2是延续的求导,等于号之后 ∂ + s 2 ∂ W \frac{\partial ^+s_2}{\partial W} ∂W∂+s2只考虑当前的时间戳;
(3)考虑第三个时刻的损失 L 3 L_3 L3对W的偏导
∂ L 3 ∂ W = ∂ L 3 ∂ o 3 ⋅ ∂ o 3 ∂ s 3 ⋅ ∂ s 3 ∂ W = ∂ L 3 ∂ o 3 ⋅ ∂ o 3 ∂ s 3 ⋅ ( ∂ s 3 ∂ s 2 ⋅ ∂ s 2 ∂ W + ∂ + s 3 ∂ W ) = ∂ L 3 ∂ o 3 ⋅ ∂ o 3 ∂ s 3 ⋅ ( ∂ s 3 ∂ s 2 ⋅ ∂ s 2 ∂ s 1 ⋅ ∂ + s 1 ∂ W + ∂ s 3 ∂ s 2 ⋅ ∂ + s 2 ∂ W + ∂ + s 3 ∂ W ) \frac{\partial L_3}{\partial W}=\frac{\partial L_3}{\partial o_3} \cdot \frac{\partial o_3}{\partial s_3}\cdot \frac{\partial s_3}{\partial W}=\frac{\partial L_3}{\partial o_3} \cdot \frac{\partial o_3}{\partial s_3}\cdot (\frac{\partial s_3}{\partial s_2}\cdot \frac{\partial s_2}{\partial W}+\frac{\partial ^+s_3}{\partial W})=\frac{\partial L_3}{\partial o_3} \cdot \frac{\partial o_3}{\partial s_3}\cdot (\frac{\partial s_3}{\partial s_2}\cdot \frac{\partial s_2}{\partial s_1}\cdot \frac{\partial ^+s_1}{\partial W}+\frac{\partial s_3}{\partial s_2}\cdot \frac{\partial ^+s_2}{\partial W}+ \frac{\partial ^+s_3}{\partial W}) ∂W∂L3=∂o3∂L3⋅∂s3∂o3⋅∂W∂s3=∂o3∂L3⋅∂s3∂o3⋅(∂s2∂s3⋅∂W∂s2+∂W∂+s3)=∂o3∂L3⋅∂s3∂o3⋅(∂s2∂s3⋅∂s1∂s2⋅∂W∂+s1+∂s2∂s3⋅∂W∂+s2+∂W∂+s3)
(4)考虑 t 时刻的损失 L t L_t Lt对W的偏导,从上面总结一下规律,有
∂ L t ∂ W = ∂ L t ∂ o t ⋅ ∂ o t ∂ s t ⋅ ∂ s t ∂ W = ∂ L t ∂ o t ⋅ ∂ o t ∂ s t ⋅ ∑ i = 1 t ∂ s t ∂ s i ⋅ ∂ + s i ∂ W \frac{\partial L_t}{\partial W}=\frac{\partial L_t}{\partial o_t} \cdot \frac{\partial o_t}{\partial s_t}\cdot \frac{\partial s_t}{\partial W}=\frac{\partial L_t}{\partial o_t} \cdot \frac{\partial o_t}{\partial s_t}\cdot \sum_{i=1}^t\frac{\partial s_t}{\partial s_i} \cdot \frac{\partial ^+s_i}{\partial W} ∂W∂Lt=∂ot∂Lt⋅∂st∂ot⋅∂W∂st=∂ot∂Lt⋅∂st∂ot⋅∑i=1t∂si∂st⋅∂W∂+si
可以看到写成这样简洁多了,其中有一个 ∂ s t ∂ s i \frac{\partial s_t}{\partial s_i} ∂si∂st是造成梯度消失或梯度爆炸的罪魁祸首,因为这个式子展开之后:
∂ s t ∂ s i = ∏ k = i t − 1 ∂ s k + 1 ∂ s k \frac{\partial s_t}{\partial s_i} = \prod_{k=i}^{t-1}\frac{\partial s_{k+1}}{\partial s_k} ∂si∂st=∏k=it−1∂sk∂sk+1
这是多个偏导连乘的运算,就可能出现梯度消失或梯度爆炸的情况。
(5)注意上述的公式是某一时刻损失对W的偏导,不是总体损失对W的偏导,理解到此应该差不多了。