文章目录
矩阵求导方法
- 维度相容原则:假设每个中间变量量的维度都不不⼀一样,看怎么摆能把雅克比矩阵的维度摆成矩阵乘法规则允许的形式。只要把矩阵维度倒腾顺了了,公式也就对了了。
- 设有 f ( Y ) : R m × p → R f ( Y ) : \mathbb { R } ^ { m \times p } \rightarrow \mathbb { R } f(Y):Rm×p→R, Y = A X + B : R n × p → R m × p Y = A X + B : \mathbb { R } ^ { n \times p } \rightarrow \mathbb { R } ^ { m \times p } Y=AX+B:Rn×p→Rm×p,则 ∇ X f ( A X + B ) = A T ∇ Y f \nabla _ { X } f ( A X + B ) = A ^ { T } \nabla _ { Y } f ∇Xf(AX+B)=AT∇Yf,即 ∂ f ∂ X = A T ∂ f ∂ Y \frac { \partial f } { \partial X} = A ^ { T } \frac { \partial f} { \partial Y } ∂X∂f=AT∂Y∂f
- 设有 f ( Y ) : R m × p → R f ( Y ) : \mathbb { R } ^ { m \times p } \rightarrow \mathbb { R } f(Y):Rm×p→R, Y = X A + B : R m × n → R m × p Y = X A+ B : \mathbb { R } ^ { m \times n } \rightarrow \mathbb { R } ^ { m \times p } Y=XA+B:Rm×n→Rm×p,则 ∇ X f ( X A + B ) = ∇ Y f A T \nabla _ { X } f (XA + B ) = { \nabla _ { Y } f}A ^ { T } ∇Xf(XA+B)=∇YfAT,即 ∂ f ∂ X = ∂ f ∂ Y A T \frac { \partial f } { \partial X} = \frac { \partial f} { \partial Y }{A ^ { T } } ∂X∂f=∂Y∂fAT
证明
在前向传播过程中,X的shape(N,D),W的shape(D,C),Y=XW。现在,我们假设N = 2, D = 2, C = 3。那么
X = ( x 1 , 1 x 1 , 2 x 2 , 1 x 2 , 2 ) W = ( w 1 , 1 w 1 , 2 w 1 , 3 w 2 , 1 w 2 , 2 w 2 , 3 ) X = \left( \begin{array} { l l } { x _ { 1,1 } } & { x _ { 1,2 } } \\ { x _ { 2,1 } } & { x _ { 2,2 } } \end{array} \right) \qquad W = \left( \begin{array} { l l l } { w _ { 1,1 } } & { w _ { 1,2 } } & { w _ { 1,3 } } \\ { w _ { 2,1 } } & { w _ { 2,2 } } & { w _ { 2,3 } } \end{array} \right) X=(x1,1x2,1x1,2x2,2)W=(w1,1w2,1w1,2w2,2w1,3w2,3) Y = X W = ( x 1 , 1 w 1 , 1 + x 1 , 2 w 2 , 1 x 1 , 1 w 1 , 2 + x 1 , 2 w 2 , 2 x 1 , 1 w 1 , 3 + x 1 , 2 w 2 , 3 x 2 , 1 w 1 , 1 + x 2 , 2 w 2 , 1 x 2 , 1 w 1 , 2 + x 2 , 2 w 2 , 2 x 2 , 1 w 1 , 3 + x 2 , 2 w 2 , 3 ) Y = X W = \left( \begin{array} { l l } { x _ { 1,1 } w _ { 1,1 } + x _ { 1,2 } w _ { 2,1 } } & { x _ { 1,1 } w _ { 1,2 } + x _ { 1,2 } w _ { 2,2 } } & { x _ { 1,1 } w _ { 1,3 } + x _ { 1,2 } w _ { 2,3 } } \\ { x _ { 2,1 } w _ { 1,1 } + x _ { 2,2 } w _ { 2,1 } } & { x _ { 2,1 } w _ { 1,2 } + x _ { 2,2 } w _ { 2,2 } } & { x _ { 2,1 } w _ { 1,3 } + x _ { 2,2 } w _ { 2,3 } } \end{array} \right) Y=XW=(x1,1w1,1+x1,2w2,1x2,1w1,1+x2,2w2,1x1,1w1,2+x1,2w2,2x2,1w1,2+x2,2w2,2x1,1w1,3+x1,2w2,3x2,1w1,3+x2,2w2,3)在前向传播结束后,我们通过输出Y计算得到损失函数L,然后求得 ∂ L ∂ Y \frac { \partial L } { \partial Y } ∂Y∂L ∂ L ∂ Y = ( ∂ L ∂ y 1 , 1 ∂ L ∂ y 1 , 2 ∂ L ∂ y 1 , 3 ∂ L ∂ y 2 , 1 ∂ L ∂ y 2 , 2 ∂