近期把DNN的反向传播又好好的研究了一下。之前一直有疑虑是因为很多文档里边出现 ∂ z ( l + 1 ) ∂ z ( l ) \frac{\partial z^{(l+1)}}{\partial z^{(l)}} ∂z(l)∂z(l+1)这种表达式,然后 z ( l + 1 ) z^{(l+1)} z(l+1)和 z ( l ) z^{(l)} z(l)还是矩阵,这下就变得非常烦人了,因为没有哪本数学书定义了矩阵对矩阵的导数。只有标量函数对矩阵,矩阵对标量,标量对向量,向量对标量以及向量对向量。所以我觉得有必要在好好把这块弄一下,写清楚。
首先是DNN的模型:
{
z
(
l
+
1
)
=
θ
(
l
+
1
)
⋅
a
(
l
)
+
b
(
l
+
1
)
⋅
1
T
,
a
(
l
)
=
g
(
z
(
l
+
1
)
)
,
l
=
0
,
1
,
2
,
…
,
N
J
=
J
(
a
(
N
)
)
\left\{ \begin{array}{ll} z^{(l+1)} &= \theta^{(l+1)}\cdot a^{(l)}+b^{(l+1)}\cdot\boldsymbol{1}^T, & \\ a^{(l)} &= g(z^{(l+1)}),& l=0,1,2,\dots,N \\ J&=J(a^{(N)}) & \end{array} \right.
⎩⎨⎧z(l+1)a(l)J=θ(l+1)⋅a(l)+b(l+1)⋅1T,=g(z(l+1)),=J(a(N))l=0,1,2,…,N
这里边,
a
(
0
)
=
X
a^{(0)}=X
a(0)=X也就是输入,
1
\boldsymbol{1}
1是列向量。然后:
X
=
(
∣
…
∣
X
1
…
X
m
∣
…
∣
)
X= \begin{pmatrix} | & \dots & | \\ X_1 & \dots & X_m \\ | & \dots & | \\ \end{pmatrix}
X=⎝⎛∣X1∣………∣Xm∣⎠⎞
也就是说,一共有m个样本。
通常的文章怎么描述的呢?定义 δ ( l ) = ∂ J ∂ z ( l ) \delta^{(l)}=\frac{\partial J}{\partial z^{(l)}} δ(l)=∂z(l)∂J,假如计算出了 δ ( l ) \delta^{(l)} δ(l)那么 ∂ J ∂ θ ( l ) = ∂ J ∂ z ( l ) ⋅ ∂ z ( l ) ∂ θ ( l ) \frac{\partial J}{\partial \theta^{(l)}}=\frac{\partial J}{\partial z^{(l)}}\cdot\frac{\partial z^{(l)}}{\partial \theta^{(l)}} ∂θ(l)∂J=∂z(l)∂J⋅∂θ(l)∂z(l),然后 ∂ J ∂ z ( l − 1 ) = ∂ J ∂ z ( l ) ⋅ ∂ z ( l ) ∂ z ( l − 1 ) \frac{\partial J}{\partial z^{(l-1)}}=\frac{\partial J}{\partial z^{(l)}}\cdot\frac{\partial z^{(l)}}{\partial z^{(l-1)}} ∂z(l−1)∂J=∂z(l)∂J⋅∂z(l−1)∂z(l),由于 ∂ J ∂ z ( N ) \frac{\partial J}{\partial z^{(N)}} ∂z(N)∂J很容易计算,所以后边递推就可以了。但是问题在于 ∂ z ( l ) ∂ z ( l − 1 ) \frac{\partial z^{(l)}}{\partial z^{(l-1)}} ∂z(l−1)∂z(l)到底是啥?雅可比矩阵吗? z ( l ) z^{(l)} z(l)和 z ( l − 1 ) z^{(l-1)} z(l−1)都是矩阵,没有一本数学书有这么直接写的。矩阵对矩阵的导数目前还处于undefined的状态。所以这个符号其实是没有严格定义的。只不过按照其他的方式推导出来后,结果看上去很像,所以就这么写了,但是如果真的较真说这个矩阵对矩阵的定义是什么怎么算,那就没法严格的说了。所以这篇文章就是仔细的把这块严格的做一下。
然后有几个公式定理需要推导一下,推到完了,很多东西就迎刃而解了。
f : R m × n ↦ R f:R^{m\times n}\mapsto R f:Rm×n↦R也就是一个矩阵的标量函数,那么若 g : R p × q ↦ R m × n g:R^{p\times q}\mapsto R^{m\times n} g:Rp×q↦Rm×n,那么复合函数: f ∘ g : R p × q ↦ R f\circ g:R^{p\times q}\mapsto R f∘g:Rp×q↦R,例如 f ( z ) , z = θ X f(z),\ z=\theta X f(z), z=θX,又如 f ( a ) , a = g ( z ) f(a),\ a=g(z) f(a), a=g(z)。在这种情况下,我们希望得到 ∂ f ∂ θ \frac{\partial f}{\partial \theta} ∂θ∂f或者 ∂ f ∂ z \frac{\partial f}{\partial z} ∂z∂f,该如何求解?其实这种情况,需要用到matrix vectorization和kronecker product,但是我们所遇到的恰好是线性变换和element-wise function,所以对于这两种情况,完全可以简化。
Lemma 1
若
g
g
g是一个矩阵左乘或者右乘,也就是
g
=
θ
X
g=\theta X
g=θX这种情况,那么有:(
[
⋅
]
i
,
j
[\cdot]_{i,j}
[⋅]i,j是取一个矩阵的第i行第j列的元素)
[
∂
f
∂
X
]
i
,
j
=
∑
m
∑
n
∂
f
∂
g
m
,
n
⋅
∂
g
m
,
n
∂
X
i
,
j
=
∑
m
∑
n
∂
f
∂
g
m
,
n
⋅
∂
∑
k
θ
m
,
k
X
k
,
n
∂
X
i
,
j
=
∑
m
∂
f
∂
g
m
,
j
⋅
θ
m
,
i
=
[
θ
T
⋅
∂
f
∂
g
]
i
,
j
\begin{aligned} \left[\frac{\partial f}{\partial X}\right]_{i,j} &=\sum_m{\sum_n{ \frac{\partial f}{\partial g_{m,n}}\cdot\frac{\partial g_{m,n}}{\partial X_{i,j}} }} \\ &=\sum_m{\sum_n{ \frac{\partial f}{\partial g_{m,n}}\cdot\frac{\partial \sum_k{\theta_{m,k}X_{k,n}} }{\partial X_{i,j}} }} \\ &=\sum_m{\frac{\partial f}{\partial g_{m,j}}\cdot\theta_{m,i}} \\ &= \left[\theta^T\cdot\frac{\partial f}{\partial g}\right]_{i,j} \end{aligned}
[∂X∂f]i,j=m∑n∑∂gm,n∂f⋅∂Xi,j∂gm,n=m∑n∑∂gm,n∂f⋅∂Xi,j∂∑kθm,kXk,n=m∑∂gm,j∂f⋅θm,i=[θT⋅∂g∂f]i,j
因此:
∂
f
∂
X
=
θ
T
⋅
∂
f
∂
g
\frac{\partial f}{\partial X}=\theta^T\cdot\frac{\partial f}{\partial g}
∂X∂f=θT⋅∂g∂f
其中第一个等号是全微分公式,第二个等号是矩阵乘法展开,第三个等号是因为
k
≠
i
,
n
≠
j
k\neq i,\ n\neq j
k=i, n=j时
∂
θ
m
,
k
X
k
,
n
∂
X
i
,
j
=
0
\frac{\partial \theta_{m,k}X_{k,n}}{\partial X_{i,j}} =0
∂Xi,j∂θm,kXk,n=0,最后一个等号就是矩阵乘法了。
同理:
∂
f
∂
θ
=
∂
f
∂
g
⋅
X
T
\frac{\partial f}{\partial \theta}=\frac{\partial f}{\partial g}\cdot X^T
∂θ∂f=∂g∂f⋅XT
Lemma 2
假如
g
g
g是一个非线性函数,但是是一个element-wise的函数,那么:
[
∂
f
∂
X
]
i
,
j
=
∑
m
∑
n
∂
f
∂
g
m
,
n
⋅
∂
g
m
,
n
∂
X
i
,
j
=
[
∂
f
∂
a
]
i
,
j
⋅
[
g
′
(
z
)
]
i
,
j
\begin{aligned} \left[\frac{\partial f}{\partial X}\right]_{i,j} &=\sum_m{\sum_n{ \frac{\partial f}{\partial g_{m,n}}\cdot\frac{\partial g_{m,n}}{\partial X_{i,j}} }} \\ &= \left[\frac{\partial f}{\partial a}\right]_{i,j}\cdot \left[g'(z)\right]_{i,j} \end{aligned}
[∂X∂f]i,j=m∑n∑∂gm,n∂f⋅∂Xi,j∂gm,n=[∂a∂f]i,j⋅[g′(z)]i,j
因此:
∂
f
∂
X
=
∂
f
∂
a
⊙
g
′
(
z
)
\frac{\partial f}{\partial X}=\frac{\partial f}{\partial a}\odot g'(z)
∂X∂f=∂a∂f⊙g′(z)
这里边
⊙
\odot
⊙是hardamard product,其实就是元素乘法。
有了Lemma 1和Lemma 2之后很多东西就迎刃而解了。定义
δ
(
l
)
=
∂
J
∂
z
(
l
)
\delta^{(l)}=\frac{\partial J}{\partial z^{(l)}}
δ(l)=∂z(l)∂J,而
z
(
l
)
=
θ
(
l
)
⋅
a
(
l
−
1
)
+
b
(
l
)
⋅
1
T
z^{(l)}= \theta^{(l)}\cdot a^{(l-1)}+b^{(l)}\cdot\boldsymbol{1}^T
z(l)=θ(l)⋅a(l−1)+b(l)⋅1T
那么显然:
∂
J
∂
θ
(
l
)
=
δ
(
l
)
⋅
(
a
(
l
−
1
)
)
T
∂
J
∂
b
(
l
)
=
δ
(
l
)
⋅
1
\begin{aligned} \frac{\partial J}{\partial \theta^{(l)}}&=\delta^{(l)}\cdot (a^{(l-1)})^T \\ \frac{\partial J}{\partial b^{(l)}}&=\delta^{(l)}\cdot \boldsymbol{1} \end{aligned}
∂θ(l)∂J∂b(l)∂J=δ(l)⋅(a(l−1))T=δ(l)⋅1
那么对于有了
δ
(
l
+
1
)
\delta^{(l+1)}
δ(l+1)计算
δ
(
l
)
\delta^{(l)}
δ(l)呢?首先由于
z
(
l
+
1
)
=
θ
(
l
+
1
)
⋅
a
(
l
)
+
b
(
l
+
1
)
⋅
1
T
z^{(l+1)} = \theta^{(l+1)}\cdot a^{(l)}+b^{(l+1)}\cdot\boldsymbol{1}^T
z(l+1)=θ(l+1)⋅a(l)+b(l+1)⋅1T,所以:
∂
J
∂
a
(
l
)
=
(
θ
(
l
+
1
)
)
T
⋅
δ
(
l
+
1
)
\frac{\partial J}{\partial a^{(l)}}=\left(\theta^{(l+1)}\right)^T\cdot \delta^{(l+1)}
∂a(l)∂J=(θ(l+1))T⋅δ(l+1)
这里用了Lemma 1的第一个,然后根据Lemma 2,
a
(
l
)
=
g
(
z
(
l
)
)
a^{(l)}=g(z^{(l)})
a(l)=g(z(l)),因此:
∂
J
∂
z
(
l
)
=
(
θ
(
l
+
1
)
)
T
⋅
δ
(
l
+
1
)
⊙
g
′
(
z
(
l
)
)
\frac{\partial J}{\partial z^{(l)}}=\left(\theta^{(l+1)}\right)^T\cdot \delta^{(l+1)}\odot g'(z^{(l)})
∂z(l)∂J=(θ(l+1))T⋅δ(l+1)⊙g′(z(l))
这样就完成了推导。
我认为这种方式比
∂
z
(
l
+
1
)
∂
z
(
l
)
\frac{\partial z^{(l+1)}}{\partial z^{(l)}}
∂z(l)∂z(l+1)这种写法要清晰明白很多,因为矩阵对矩阵的导数一定是得每个元素都要求导。这样就出来一个mn x mn矩阵了,但是目前这种方式,就明白清晰了很多。
另外如果吧bias一项放入
θ
\theta
θ里边去,然后
a
(
l
)
a^{(l)}
a(l)不上一行1,也是可以的,就直接用:
∂
J
∂
θ
(
l
)
=
δ
(
l
)
⋅
(
a
(
l
−
1
)
)
T
\frac{\partial J}{\partial \theta^{(l)}}=\delta^{(l)}\cdot (a^{(l-1)})^T
∂θ(l)∂J=δ(l)⋅(a(l−1))T即可。