之前对于rnn反向传播求导的过程一直不是特别的熟悉,最近深入理解了一下rnn的反向传播推导过程,特总结在此,以便学习。
首先先给出一张经典的rnn反向传播流程的图片
写出里面各项内容的相应的关系
o
t
∗
=
φ
(
V
s
t
)
=
φ
(
V
ϕ
(
U
x
t
+
W
s
t
−
1
)
)
o_{t}^{*} = \varphi(Vs_{t}) = \varphi(V\phi(Ux_{t}+Ws_{t-1}))
ot∗=φ(Vst)=φ(Vϕ(Uxt+Wst−1))
令
o
t
∗
=
V
s
t
,
s
t
∗
=
U
x
t
+
W
s
t
−
1
o_{t}^{*} = Vs_{t},s_{t}^{*} = Ux_{t}+Ws_{t-1}
ot∗=Vst,st∗=Uxt+Wst−1
(多说一句,这里的
s
t
∗
=
U
x
t
+
W
s
t
−
1
s_{t}^{*} = Ux_{t}+Ws_{t-1}
st∗=Uxt+Wst−1实际上就是pytorch当中的nn.rnn实现的矩阵相乘的过程)
则有:
o
t
=
φ
(
o
t
∗
)
,
s
t
=
ϕ
(
s
t
∗
)
o_{t} = \varphi(o_{t}^{*}),s_{t} = \phi(s^{*}_{t})
ot=φ(ot∗),st=ϕ(st∗)
接下来的求导公式主要利用节点的转移进行求导的计算,下面的"*"表示数值相乘的计算(主要用于对激活函数的求导),“X"表示矩阵相乘的计算(主要针对于非激活函数的求导)
∂
L
t
∂
O
t
∗
=
∂
L
t
∂
O
t
×
∂
O
t
∂
O
t
∗
=
∂
L
t
∂
O
t
×
φ
′
(
O
t
∗
)
\frac{\partial L_{t}}{\partial O_{t}^{*}} = \frac{\partial L_{t}}{\partial O_{t}}\times \frac{\partial{O_{t}}}{\partial{O_{t}^{*}}}= \frac{\partial L_{t}}{\partial O_{t}} \times \varphi^{'}(O_{t}^{*})
∂Ot∗∂Lt=∂Ot∂Lt×∂Ot∗∂Ot=∂Ot∂Lt×φ′(Ot∗)
(这里借助激活函数之后的
O
t
O_{t}
Ot作为中间变量进行求解)
同理利用中间变量求导可得
∂
L
t
∂
V
=
∂
L
t
∂
O
t
×
∂
O
t
∂
O
t
∗
∗
∂
O
t
∗
∂
V
=
∂
L
t
∂
O
t
∗
∗
∂
O
t
∗
∂
V
=
∂
L
t
∂
O
t
×
φ
′
(
O
t
∗
)
∗
∂
O
t
∗
∂
V
\frac{\partial L_{t}}{\partial V} = \frac{\partial L_{t}}{\partial O_{t}} \times \frac{\partial O_{t}}{\partial O_{t}^{*}} * \frac{\partial O_{t}^{*}}{\partial V} = \frac{\partial L_{t}}{\partial O_{t}^{*}} * \frac{\partial O_{t}^{*}}{\partial V} = \frac{\partial L_{t}}{\partial O_{t}} \times \varphi^{'}(O_{t}^{*}) * \frac{\partial O_{t}^{*}}{\partial V}
∂V∂Lt=∂Ot∂Lt×∂Ot∗∂Ot∗∂V∂Ot∗=∂Ot∗∂Lt∗∂V∂Ot∗=∂Ot∂Lt×φ′(Ot∗)∗∂V∂Ot∗
可见对矩阵V的分析即为普通的反向传播算法,相对而言比较平凡,由
L
=
∑
t
=
1
N
(
∂
L
t
∂
O
t
×
φ
′
(
O
t
∗
)
∗
∂
O
t
∗
∂
V
)
L = \sum_{t=1}^N(\frac{\partial L_{t}}{\partial O_{t}} \times \varphi^{'}(O_{t}^{*}) * \frac{\partial O_{t}^{*}}{\partial V})
L=∑t=1N(∂Ot∂Lt×φ′(Ot∗)∗∂V∂Ot∗)
但是由于RNN算法的主要难点在于它state之间的通信,亦即梯度除了按照空间结构传播
(
o
t
−
>
s
t
−
>
x
t
)
(o_{t}->s_{t}->x_{t})
(ot−>st−>xt)以外,还得沿着时间通道传播
(
s
t
−
>
s
t
−
1
−
>
.
.
.
−
>
s
1
)
(s_{t}->s_{t-1}->...->s_{1})
(st−>st−1−>...−>s1),这导致我们比较难将相应的RNN的BP算法写成一个统一的形式,为此我们可以采用"循环"的方法来计算各个梯度。
由于是反向传播算法,所以t应从n开始降序循环至1,在此期间(若需要初始化,则初始化为0向量或0矩阵)
所以接下来需要计算时间通道上的"局部梯度”:
∂
L
t
∂
s
t
∗
=
(
∂
L
t
∂
O
t
∗
×
∂
O
t
∗
∂
O
t
∗
∂
O
t
∂
S
t
)
∗
∂
S
t
∂
S
t
∗
\frac{\partial L_{t}}{\partial s_{t}^{*}} = (\frac{\partial L_{t}}{\partial O_{t}^{*}} \times \frac{\partial O_{t}^{*}}{\partial O_{t}} * \frac{\partial O_{t}}{\partial S_{t}}) * \frac{\partial S_{t}}{\partial S_{t}^{*}}
∂st∗∂Lt=(∂Ot∗∂Lt×∂Ot∂Ot∗∗∂St∂Ot)∗∂St∗∂St
rnn反向传播求导详尽过程及思路
最新推荐文章于 2023-01-17 10:34:53 发布