rnn反向传播求导详尽过程及思路

之前对于rnn反向传播求导的过程一直不是特别的熟悉,最近深入理解了一下rnn的反向传播推导过程,特总结在此,以便学习。
首先先给出一张经典的rnn反向传播流程的图片
rnn反向传播图片写出里面各项内容的相应的关系
o t ∗ = φ ( V s t ) = φ ( V ϕ ( U x t + W s t − 1 ) ) o_{t}^{*} = \varphi(Vs_{t}) = \varphi(V\phi(Ux_{t}+Ws_{t-1})) ot=φ(Vst)=φ(Vϕ(Uxt+Wst1))
o t ∗ = V s t , s t ∗ = U x t + W s t − 1 o_{t}^{*} = Vs_{t},s_{t}^{*} = Ux_{t}+Ws_{t-1} ot=Vst,st=Uxt+Wst1
(多说一句,这里的 s t ∗ = U x t + W s t − 1 s_{t}^{*} = Ux_{t}+Ws_{t-1} st=Uxt+Wst1实际上就是pytorch当中的nn.rnn实现的矩阵相乘的过程)
则有:
o t = φ ( o t ∗ ) , s t = ϕ ( s t ∗ ) o_{t} = \varphi(o_{t}^{*}),s_{t} = \phi(s^{*}_{t}) ot=φ(ot),st=ϕ(st)
接下来的求导公式主要利用节点的转移进行求导的计算,下面的"*"表示数值相乘的计算(主要用于对激活函数的求导),“X"表示矩阵相乘的计算(主要针对于非激活函数的求导)
利用激活函数前后的中间变量求导 ∂ L t ∂ O t ∗ = ∂ L t ∂ O t × ∂ O t ∂ O t ∗ = ∂ L t ∂ O t × φ ′ ( O t ∗ ) \frac{\partial L_{t}}{\partial O_{t}^{*}} = \frac{\partial L_{t}}{\partial O_{t}}\times \frac{\partial{O_{t}}}{\partial{O_{t}^{*}}}= \frac{\partial L_{t}}{\partial O_{t}} \times \varphi^{'}(O_{t}^{*}) OtLt=OtLt×OtOt=OtLt×φ(Ot)
(这里借助激活函数之后的 O t O_{t} Ot作为中间变量进行求解)
同理利用中间变量求导可得
∂ L t ∂ V = ∂ L t ∂ O t × ∂ O t ∂ O t ∗ ∗ ∂ O t ∗ ∂ V = ∂ L t ∂ O t ∗ ∗ ∂ O t ∗ ∂ V = ∂ L t ∂ O t × φ ′ ( O t ∗ ) ∗ ∂ O t ∗ ∂ V \frac{\partial L_{t}}{\partial V} = \frac{\partial L_{t}}{\partial O_{t}} \times \frac{\partial O_{t}}{\partial O_{t}^{*}} * \frac{\partial O_{t}^{*}}{\partial V} = \frac{\partial L_{t}}{\partial O_{t}^{*}} * \frac{\partial O_{t}^{*}}{\partial V} = \frac{\partial L_{t}}{\partial O_{t}} \times \varphi^{'}(O_{t}^{*}) * \frac{\partial O_{t}^{*}}{\partial V} VLt=OtLt×OtOtVOt=OtLtVOt=OtLt×φ(Ot)VOt
可见对矩阵V的分析即为普通的反向传播算法,相对而言比较平凡,由 L = ∑ t = 1 N ( ∂ L t ∂ O t × φ ′ ( O t ∗ ) ∗ ∂ O t ∗ ∂ V ) L = \sum_{t=1}^N(\frac{\partial L_{t}}{\partial O_{t}} \times \varphi^{'}(O_{t}^{*}) * \frac{\partial O_{t}^{*}}{\partial V}) L=t=1N(OtLt×φ(Ot)VOt)
但是由于RNN算法的主要难点在于它state之间的通信,亦即梯度除了按照空间结构传播 ( o t − > s t − > x t ) (o_{t}->s_{t}->x_{t}) (ot>st>xt)以外,还得沿着时间通道传播 ( s t − > s t − 1 − > . . . − > s 1 ) (s_{t}->s_{t-1}->...->s_{1}) (st>st1>...>s1),这导致我们比较难将相应的RNN的BP算法写成一个统一的形式,为此我们可以采用"循环"的方法来计算各个梯度。
由于是反向传播算法,所以t应从n开始降序循环至1,在此期间(若需要初始化,则初始化为0向量或0矩阵)
所以接下来需要计算时间通道上的"局部梯度”:
∂ L t ∂ s t ∗ = ( ∂ L t ∂ O t ∗ × ∂ O t ∗ ∂ O t ∗ ∂ O t ∂ S t ) ∗ ∂ S t ∂ S t ∗ \frac{\partial L_{t}}{\partial s_{t}^{*}} = (\frac{\partial L_{t}}{\partial O_{t}^{*}} \times \frac{\partial O_{t}^{*}}{\partial O_{t}} * \frac{\partial O_{t}}{\partial S_{t}}) * \frac{\partial S_{t}}{\partial S_{t}^{*}} stLt=(OtLt×OtOtStOt)StSt

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值