Detailed Derivation of the RNN Gradient Formulas

As a total beginner, I found the RNN gradient formulas too complicated, so I came up with a way to break them down.

The idea is to use a syntax tree.

For the original derivation, see: RNN反向求导详解 (A Detailed Explanation of RNN Backpropagation), 格物致知, CSDN blog.

RNN forward propagation structure

$$
\begin{aligned}
o_t&=\varphi(Vs_t)=\varphi(V\phi(Ws_{t-1}+Ux_t))\\
L_t&=\text{loss}(o_t,y_t)
\end{aligned}
$$

Let $o_t^*=Vs_t$ and $s_t^*=Ux_t+Ws_{t-1}$.

Then $o_t=\varphi(o_t^*)$ and $s_t=\phi(s_t^*)$.
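To make the forward structure concrete, here is a minimal NumPy sketch of one time step. The choices $\phi=\tanh$ and $\varphi=$ sigmoid, the layer sizes, and the function name `forward_step` are assumptions for illustration; the text does not fix any of them.

```python
import numpy as np

def forward_step(x_t, s_prev, U, W, V):
    """One RNN time step (assumed: phi = tanh, varphi = sigmoid)."""
    s_star = U @ x_t + W @ s_prev          # s_t^* = U x_t + W s_{t-1}
    s_t = np.tanh(s_star)                  # s_t = phi(s_t^*)
    o_star = V @ s_t                       # o_t^* = V s_t
    o_t = 1.0 / (1.0 + np.exp(-o_star))    # o_t = varphi(o_t^*)
    return s_star, s_t, o_star, o_t

rng = np.random.default_rng(0)
U = rng.normal(size=(4, 2))   # assumed input dim 2, hidden dim 4
W = rng.normal(size=(4, 4))
V = rng.normal(size=(3, 4))   # output dim 3, matching the 3x4 V used in the example below
x_t = rng.normal(size=2)
s_prev = np.zeros(4)

s_star, s_t, o_star, o_t = forward_step(x_t, s_prev, U, W, V)
print(o_t.shape)  # (3,)
```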

Now draw $L_t$ as a syntax tree and differentiate it step by step.

[Figure: syntax tree of $L_t$]

Here $*$ denotes element-wise multiplication and $\times$ denotes matrix multiplication.
$$
\frac{\partial L_t}{\partial o_t^*}=\frac{\partial L_t}{\partial o_t}*\frac{\partial o_t}{\partial o_t^*}=\frac{\partial L_t}{\partial o_t}*\varphi'(o_t^*)\tag{1}
$$
Here $\varphi$ is treated as an element-wise activation, so the result of Eq. (1) is a vector with the same dimension as $o_t^*$.
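Eq. (1) can be sanity-checked numerically. This sketch assumes $\varphi=$ sigmoid, so $\varphi'(z)=\varphi(z)(1-\varphi(z))$, and a squared-error loss $L_t=\frac12\lVert o_t-y_t\rVert^2$, so $\partial L_t/\partial o_t=o_t-y_t$; neither choice comes from the text. The analytic gradient is compared with central finite differences.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def loss(o, y):
    # Assumed loss: squared error 0.5 * ||o - y||^2
    return 0.5 * np.sum((o - y) ** 2)

rng = np.random.default_rng(1)
o_star = rng.normal(size=3)
y = rng.normal(size=3)

o = sigmoid(o_star)
dL_do = o - y                        # derivative of the assumed loss w.r.t. o_t
dL_do_star = dL_do * o * (1 - o)     # Eq. (1): element-wise product with varphi'(o_t^*)

# Central finite differences, one component of o_t^* at a time
eps = 1e-6
numeric = np.zeros_like(o_star)
for i in range(o_star.size):
    e = np.zeros_like(o_star)
    e[i] = eps
    numeric[i] = (loss(sigmoid(o_star + e), y) - loss(sigmoid(o_star - e), y)) / (2 * eps)

print(np.allclose(dL_do_star, numeric, atol=1e-6))  # True
```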
$$
\frac{\partial L_t}{\partial V_t}=\frac{\partial L_t}{\partial o_t^*}\,[?]\,\frac{\partial o_t^*}{\partial V}\tag{2}
$$
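Overall, Eq. (2) is the derivative of a scalar with respect to a matrix, which simply means differentiating the scalar with respect to every element of the matrix. Along the way there is an intermediate value, $o_t^*$, which is a vector.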

The first factor of Eq. (2) was already computed in Eq. (1); what remains is the derivative of a matrix-vector product.

Since we are differentiating with respect to $V$, the result must have the same shape as $V$.

Let's work through a concrete example to see how this derivative is computed.

$$
\boldsymbol{o^*}=\boldsymbol{V}\times\boldsymbol{s}=
\begin{bmatrix}V_{11}&V_{12}&V_{13}&V_{14}\\V_{21}&V_{22}&V_{23}&V_{24}\\V_{31}&V_{32}&V_{33}&V_{34}\end{bmatrix}
\times
\begin{bmatrix}s_1\\s_2\\s_3\\s_4\end{bmatrix}=
\begin{bmatrix}V_{11}s_1+V_{12}s_2+V_{13}s_3+V_{14}s_4\\V_{21}s_1+V_{22}s_2+V_{23}s_3+V_{24}s_4\\V_{31}s_1+V_{32}s_2+V_{33}s_3+V_{34}s_4\end{bmatrix}=
\begin{bmatrix}o^*_1\\o^*_2\\o^*_3\end{bmatrix}\tag{3}
$$
Entry $V_{ij}$ appears only in $o^*_i$, so each chain rule has a single term:

$$
\begin{aligned}
\frac{\partial L}{\partial V_{11}}&=\frac{\partial L}{\partial o^*_1}\frac{\partial o^*_1}{\partial V_{11}}=\frac{\partial L}{\partial o^*_1}s_1\\
\frac{\partial L}{\partial V_{12}}&=\frac{\partial L}{\partial o^*_1}\frac{\partial o^*_1}{\partial V_{12}}=\frac{\partial L}{\partial o^*_1}s_2\\
&\;\;\vdots\\
\frac{\partial L}{\partial V_{34}}&=\frac{\partial L}{\partial o^*_3}\frac{\partial o^*_3}{\partial V_{34}}=\frac{\partial L}{\partial o^*_3}s_4
\end{aligned}
$$
Collecting all twelve entries, $\frac{\partial L}{\partial V}$ is a column vector times a row vector:

$$
\frac{\partial L}{\partial V}=\begin{bmatrix}\frac{\partial L}{\partial o^*_1}\\\frac{\partial L}{\partial o^*_2}\\\frac{\partial L}{\partial o^*_3}\end{bmatrix}\times\begin{bmatrix}s_1&s_2&s_3&s_4\end{bmatrix}
$$
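The element-by-element pattern and the column-times-row form are the same computation. A small NumPy sketch with the 3x4 shapes of Eq. (3); the random vector `g` is an assumption standing in for $\partial L/\partial o^*$.

```python
import numpy as np

rng = np.random.default_rng(2)
s = rng.normal(size=4)            # s_t, length 4 as in Eq. (3)
g = rng.normal(size=3)            # stands in for dL/do^*, length 3

# Element by element: dL/dV_ij = dL/do_i^* * s_j
dL_dV_loop = np.zeros((3, 4))
for i in range(3):
    for j in range(4):
        dL_dV_loop[i, j] = g[i] * s[j]

# Same thing as (column vector) x (row vector), i.e. an outer product
dL_dV_outer = np.outer(g, s)

print(dL_dV_outer.shape)                       # (3, 4), same shape as V
print(np.allclose(dL_dV_loop, dL_dV_outer))    # True
```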
So Eq. (2) should be written as

$$
\frac{\partial L_t}{\partial V_t}=\frac{\partial L_t}{\partial o_t^*}\times\frac{\partial o_t^*}{\partial V}=\frac{\partial L_t}{\partial o_t^*}\times s_t^T\tag{4}
$$
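Eq. (4) can be checked end to end against finite differences of the full path $V\to o_t^*\to o_t\to L_t$. This sketch assumes $\varphi=$ sigmoid and the squared-error loss $\frac12\lVert o_t-y_t\rVert^2$; `L_of_V` is a hypothetical helper name.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def L_of_V(V, s, y):
    # Assumed: varphi = sigmoid, loss = 0.5 * ||o - y||^2
    o = sigmoid(V @ s)
    return 0.5 * np.sum((o - y) ** 2)

rng = np.random.default_rng(3)
V = rng.normal(size=(3, 4))
s = rng.normal(size=4)
y = rng.normal(size=3)

o = sigmoid(V @ s)
dL_do_star = (o - y) * o * (1 - o)       # Eq. (1)
dL_dV = np.outer(dL_do_star, s)          # Eq. (4): dL/do^* x s^T

# Central differences over every entry of V
eps = 1e-6
numeric = np.zeros_like(V)
for idx in np.ndindex(V.shape):
    E = np.zeros_like(V)
    E[idx] = eps
    numeric[idx] = (L_of_V(V + E, s, y) - L_of_V(V - E, s, y)) / (2 * eps)

print(np.allclose(dL_dV, numeric, atol=1e-6))  # True
```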
Next, compute the derivative of $L_t$ with respect to $s_t$, again using Eq. (3).

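Why does it feel like I never learned this form of the chain rule? It is the multivariable chain rule: $L$ depends on each $s_j$ through every component $o_i^*$, so the contributions add up, $\frac{\partial L}{\partial s_j}=\sum_i\frac{\partial L}{\partial o_i^*}\frac{\partial o_i^*}{\partial s_j}$.
Image source: https://wenku.baidu.com/view/0c28ff2249d7c1c708a1284ac850ad02de8007c1.html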

$$
\begin{aligned}
\frac{\partial L}{\partial s_1}&=\frac{\partial L}{\partial o_1^*}\frac{\partial o_1^*}{\partial s_1}+\frac{\partial L}{\partial o_2^*}\frac{\partial o_2^*}{\partial s_1}+\frac{\partial L}{\partial o_3^*}\frac{\partial o_3^*}{\partial s_1}\\
&\;\;\vdots\\
\frac{\partial L}{\partial s_4}&=\frac{\partial L}{\partial o_1^*}\frac{\partial o_1^*}{\partial s_4}+\frac{\partial L}{\partial o_2^*}\frac{\partial o_2^*}{\partial s_4}+\frac{\partial L}{\partial o_3^*}\frac{\partial o_3^*}{\partial s_4}
\end{aligned}
$$
$$
\begin{aligned}
\frac{\partial L}{\partial s}&=\begin{bmatrix}
\frac{\partial L}{\partial o_1^*}\frac{\partial o_1^*}{\partial s_1}+\frac{\partial L}{\partial o_2^*}\frac{\partial o_2^*}{\partial s_1}+\frac{\partial L}{\partial o_3^*}\frac{\partial o_3^*}{\partial s_1}\\
\frac{\partial L}{\partial o_1^*}\frac{\partial o_1^*}{\partial s_2}+\frac{\partial L}{\partial o_2^*}\frac{\partial o_2^*}{\partial s_2}+\frac{\partial L}{\partial o_3^*}\frac{\partial o_3^*}{\partial s_2}\\
\frac{\partial L}{\partial o_1^*}\frac{\partial o_1^*}{\partial s_3}+\frac{\partial L}{\partial o_2^*}\frac{\partial o_2^*}{\partial s_3}+\frac{\partial L}{\partial o_3^*}\frac{\partial o_3^*}{\partial s_3}\\
\frac{\partial L}{\partial o_1^*}\frac{\partial o_1^*}{\partial s_4}+\frac{\partial L}{\partial o_2^*}\frac{\partial o_2^*}{\partial s_4}+\frac{\partial L}{\partial o_3^*}\frac{\partial o_3^*}{\partial s_4}
\end{bmatrix}\\
&=\begin{bmatrix}
\frac{\partial o_1^*}{\partial s_1}&\frac{\partial o_2^*}{\partial s_1}&\frac{\partial o_3^*}{\partial s_1}\\
\frac{\partial o_1^*}{\partial s_2}&\frac{\partial o_2^*}{\partial s_2}&\frac{\partial o_3^*}{\partial s_2}\\
\frac{\partial o_1^*}{\partial s_3}&\frac{\partial o_2^*}{\partial s_3}&\frac{\partial o_3^*}{\partial s_3}\\
\frac{\partial o_1^*}{\partial s_4}&\frac{\partial o_2^*}{\partial s_4}&\frac{\partial o_3^*}{\partial s_4}
\end{bmatrix}\times\begin{bmatrix}\frac{\partial L}{\partial o_1^*}\\\frac{\partial L}{\partial o_2^*}\\\frac{\partial L}{\partial o_3^*}\end{bmatrix}\\
&=?\times\frac{\partial L}{\partial o_t^*}
\end{aligned}\tag{5}
$$
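The regrouping in Eq. (5) turns the per-component sums into a single matrix-vector product. A NumPy sketch of that step: the matrix `D` (rows indexed by $s_j$, columns by $o_i^*$, as in Eq. (5)) is estimated by central differences of $o^*=Vs$, and the random vector `g` is an assumption standing in for $\partial L/\partial o^*$.

```python
import numpy as np

rng = np.random.default_rng(4)
V = rng.normal(size=(3, 4))
s = rng.normal(size=4)
g = rng.normal(size=3)          # stands in for dL/do^*

def o_star(s):
    return V @ s

# D[j, i] = d o_i^* / d s_j, the layout of the matrix in Eq. (5), by central differences
eps = 1e-6
D = np.zeros((4, 3))
for j in range(4):
    e = np.zeros(4)
    e[j] = eps
    D[j, :] = (o_star(s + e) - o_star(s - e)) / (2 * eps)

# First line of Eq. (5): for each s_j, sum the contributions through every o_i^*
dL_ds_sum = np.array([sum(g[i] * D[j, i] for i in range(3)) for j in range(4)])

# Second line of Eq. (5): one matrix-vector product
dL_ds_mat = D @ g

print(np.allclose(dL_ds_sum, dL_ds_mat))  # True
```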
To finish the last step of Eq. (5), we first have to settle how to differentiate a vector with respect to a vector.

Reference: https://zhuanlan.zhihu.com/p/36448789

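That article says:

"However, to make this convenient in practice, we usually differentiate as if $y$ were a row vector, even when $y$ is a column vector."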

From this sentence, the usual convention is to differentiate a row vector with respect to a column vector.

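Differentiating a row vector $X$ with respect to a column vector $Y$ produces a matrix whose width (number of columns) is the length of $X$ and whose height (number of rows) is the length of $Y$; the entry in row $i$, column $j$ is $\frac{\partial X_j}{\partial Y_i}$.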
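The layout rule can be illustrated with a finite-difference sketch. `denominator_layout_jacobian` is a hypothetical helper name, and the map `f` from $\mathbb{R}^3$ to $\mathbb{R}^2$ is an arbitrary example; the resulting matrix has `len(Y)` rows and `len(X)` columns, which is the transpose of the usual (numerator-layout) Jacobian.

```python
import numpy as np

def denominator_layout_jacobian(f, y, eps=1e-6):
    """Hypothetical helper: len(y) rows, len(f(y)) columns,
    entry [i, j] = d f_j / d y_i, estimated by central differences."""
    m = f(y).size
    J = np.zeros((y.size, m))
    for i in range(y.size):
        e = np.zeros_like(y)
        e[i] = eps
        J[i, :] = (f(y + e) - f(y - e)) / (2 * eps)
    return J

def f(y):
    # An arbitrary example map from R^3 to R^2 (X = f(Y))
    return np.array([y[0] * y[1], np.sin(y[2])])

y = np.array([0.5, -1.0, 2.0])
J = denominator_layout_jacobian(f, y)
print(J.shape)  # (3, 2): height = len(Y), width = len(X)

# Analytic version: row i holds d f_j / d y_i for each j
expected = np.array([[y[1], 0.0],
                     [y[0], 0.0],
                     [0.0, np.cos(y[2])]])
print(np.allclose(J, expected, atol=1e-6))  # True
```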

So the "?" matrix in Eq. (5) is the derivative of the row vector $o_t^*$ with respect to the column vector $s_t$, and Eq. (5) becomes

$$
\frac{\partial L}{\partial s_t}=\frac{\partial o_t^*}{\partial s_t}\times\frac{\partial L}{\partial o_t^*}\tag{6}
$$
The factor $\frac{\partial o_t^*}{\partial s_t}$ in Eq. (6) can be evaluated further. By Eq. (3), $o_i^*=\sum_j V_{ij}s_j$, so $\frac{\partial o_i^*}{\partial s_j}=V_{ij}$:

$$
\begin{aligned}
\frac{\partial o_t^*}{\partial s_t}&=\begin{bmatrix}
\frac{\partial o_1^*}{\partial s_1}&\frac{\partial o_2^*}{\partial s_1}&\frac{\partial o_3^*}{\partial s_1}\\
\frac{\partial o_1^*}{\partial s_2}&\frac{\partial o_2^*}{\partial s_2}&\frac{\partial o_3^*}{\partial s_2}\\
\frac{\partial o_1^*}{\partial s_3}&\frac{\partial o_2^*}{\partial s_3}&\frac{\partial o_3^*}{\partial s_3}\\
\frac{\partial o_1^*}{\partial s_4}&\frac{\partial o_2^*}{\partial s_4}&\frac{\partial o_3^*}{\partial s_4}
\end{bmatrix}\\
&=\begin{bmatrix}
V_{11}&V_{21}&V_{31}\\
V_{12}&V_{22}&V_{32}\\
V_{13}&V_{23}&V_{33}\\
V_{14}&V_{24}&V_{34}
\end{bmatrix}\\
&=V^T
\end{aligned}
$$
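A quick numerical confirmation that this matrix is $V^T$, using central differences of $o^*=Vs$ with a random $3\times4$ $V$ (the sizes are an assumption matching Eq. (3)):

```python
import numpy as np

rng = np.random.default_rng(5)
V = rng.normal(size=(3, 4))
s = rng.normal(size=4)

# Row j holds d o_i^* / d s_j for i = 1..3, matching the matrix above
eps = 1e-6
D = np.zeros((4, 3))
for j in range(4):
    e = np.zeros(4)
    e[j] = eps
    D[j, :] = (V @ (s + e) - V @ (s - e)) / (2 * eps)

print(np.allclose(D, V.T))  # True
```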
Substituting $\frac{\partial o_t^*}{\partial s_t}=V^T$ into Eq. (6) gives

$$
\frac{\partial L}{\partial s_t}=\frac{\partial o_t^*}{\partial s_t}\times\frac{\partial L}{\partial o_t^*}=V^T\times\frac{\partial L}{\partial o_t^*}\tag{7}
$$
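Eq. (7) can be checked the same way, through the full path $s_t\to o_t^*\to o_t\to L_t$. The sketch assumes $\varphi=$ sigmoid and the squared-error loss $\frac12\lVert o_t-y_t\rVert^2$; `L_of_s` is a hypothetical helper name.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def L_of_s(s, V, y):
    # Assumed: varphi = sigmoid, loss = 0.5 * ||o - y||^2
    o = sigmoid(V @ s)
    return 0.5 * np.sum((o - y) ** 2)

rng = np.random.default_rng(6)
V = rng.normal(size=(3, 4))
s = rng.normal(size=4)
y = rng.normal(size=3)

o = sigmoid(V @ s)
dL_do_star = (o - y) * o * (1 - o)   # Eq. (1)
dL_ds = V.T @ dL_do_star             # Eq. (7)

# Central differences along each coordinate direction of s
eps = 1e-6
numeric = np.array([
    (L_of_s(s + eps * np.eye(4)[j], V, y) - L_of_s(s - eps * np.eye(4)[j], V, y)) / (2 * eps)
    for j in range(4)
])
print(np.allclose(dL_ds, numeric, atol=1e-6))  # True
```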
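At this point, all the techniques involved have been covered. Filling the derivatives into the syntax tree gives:
[Figure: updated syntax tree]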

Working through the rest of the tree shows that its remaining structure simply repeats the patterns above.
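As a sketch of that repetition (assuming $\phi$ is applied element-wise), reusing Eq. (1) for the activation, Eq. (4) for each matrix-vector product, and Eq. (7) for passing the gradient to a vector input gives the nodes below $s_t$:

$$
\begin{aligned}
\frac{\partial L_t}{\partial s_t^*}&=\frac{\partial L_t}{\partial s_t}*\phi'(s_t^*)\\
\frac{\partial L_t}{\partial W}&=\frac{\partial L_t}{\partial s_t^*}\times s_{t-1}^T\\
\frac{\partial L_t}{\partial U}&=\frac{\partial L_t}{\partial s_t^*}\times x_t^T\\
\frac{\partial L_t}{\partial s_{t-1}}&=W^T\times\frac{\partial L_t}{\partial s_t^*}
\end{aligned}
$$

Here $\frac{\partial L_t}{\partial s_t}$ comes from Eq. (7), and the last line is where the same pattern continues into step $t-1$. These formulas are for the single-step loss $L_t$; for a loss summed over all time steps, the gradient reaching $s_t$ would also include the part flowing back from $s_{t+1}$, and the gradients for $U$, $W$, $V$ would be summed over $t$.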

So I'll just fill in a couple more of them below!
[Figure: syntax tree with more derivatives filled in]
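Putting the single-step formulas together, here is a minimal NumPy sketch that computes the gradients of $L_t$ with respect to $U$, $W$, $V$ and $s_{t-1}$ by walking the syntax tree from top to bottom, then checks them against central finite differences. It assumes $\phi=\tanh$ (so $\phi'(s_t^*)=1-s_t^2$), $\varphi=$ sigmoid and the squared-error loss $\frac12\lVert o_t-y_t\rVert^2$; `step_loss`, `step_grads` and `numeric_grad` are hypothetical helper names. It covers one time step only, not full backpropagation through time.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def step_loss(U, W, V, x_t, s_prev, y_t):
    # Assumed: phi = tanh, varphi = sigmoid, loss = 0.5 * ||o_t - y_t||^2
    s_t = np.tanh(U @ x_t + W @ s_prev)
    o_t = sigmoid(V @ s_t)
    return 0.5 * np.sum((o_t - y_t) ** 2)

def step_grads(U, W, V, x_t, s_prev, y_t):
    """Gradients of the single-step loss L_t, following the syntax tree."""
    s_star = U @ x_t + W @ s_prev
    s_t = np.tanh(s_star)
    o_t = sigmoid(V @ s_t)
    dL_do_star = (o_t - y_t) * o_t * (1 - o_t)   # Eq. (1)
    dV = np.outer(dL_do_star, s_t)               # Eq. (4)
    dL_ds = V.T @ dL_do_star                     # Eq. (7)
    dL_ds_star = dL_ds * (1 - s_t ** 2)          # Eq. (1) pattern, tanh' = 1 - tanh^2
    dW = np.outer(dL_ds_star, s_prev)            # Eq. (4) pattern
    dU = np.outer(dL_ds_star, x_t)               # Eq. (4) pattern
    ds_prev = W.T @ dL_ds_star                   # Eq. (7) pattern
    return dU, dW, dV, ds_prev

def numeric_grad(f, A, eps=1e-6):
    # Central differences with respect to every entry of array A
    G = np.zeros_like(A)
    for idx in np.ndindex(A.shape):
        E = np.zeros_like(A)
        E[idx] = eps
        G[idx] = (f(A + E) - f(A - E)) / (2 * eps)
    return G

rng = np.random.default_rng(7)
U, W, V = rng.normal(size=(4, 2)), rng.normal(size=(4, 4)), rng.normal(size=(3, 4))
x_t, s_prev, y_t = rng.normal(size=2), rng.normal(size=4), rng.normal(size=3)

dU, dW, dV, ds_prev = step_grads(U, W, V, x_t, s_prev, y_t)
checks = {
    "U": np.allclose(dU, numeric_grad(lambda A: step_loss(A, W, V, x_t, s_prev, y_t), U), atol=1e-6),
    "W": np.allclose(dW, numeric_grad(lambda A: step_loss(U, A, V, x_t, s_prev, y_t), W), atol=1e-6),
    "V": np.allclose(dV, numeric_grad(lambda A: step_loss(U, W, A, x_t, s_prev, y_t), V), atol=1e-6),
    "s_prev": np.allclose(ds_prev, numeric_grad(lambda a: step_loss(U, W, V, x_t, a, y_t), s_prev), atol=1e-6),
}
print(checks)  # all four entries True
```

The gradient `ds_prev` is what a full backpropagation-through-time loop would add into the gradient of $s_{t-1}$ before repeating the same steps one time step earlier.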
