引言
在多分类问题中,一般会把输出结果传入到softmax函数中,得到最终结果。并且用交叉熵作为损失函数。本来就来分析下以交叉熵为损失函数的情况下,softmax如何求导。
对softmax求导
softmax函数为:
y ^ i = e z i ∑ k = 1 K e z k \hat y_i = \frac{e^{z_i}}{\sum_{k=1}^K e^{z_k}} y^i=∑k=1Kezkezi
这里
K
K
K是类别的总数,接下来求
y
^
i
\hat y_i
y^i对某个输出
z
j
z_j
zj的导数,
∂
y
^
i
∂
z
j
=
∂
e
z
i
∑
k
=
1
K
e
z
k
∂
z
j
\frac{\partial \hat y_i}{\partial z_j} = \frac{\partial \frac{e^{z_i}}{\sum_{k=1}^K e^{z_k}}}{\partial z_j}
∂zj∂y^i=∂zj∂∑k=1Kezkezi
这里要分两种情况,分别是 i = j i=j i=j与 i ≠ j i \neq j i=j。当 i = j i=j i=j时, e z i e^{z_i} ezi对 z j z_j zj的导数为 e z i e^{z_i} ezi,否则当 i ≠ j i \neq j i=j时,导数为 0 0 0。
当
i
=
j
i = j
i=j,
∂
y
^
i
∂
z
j
=
e
z
i
⋅
∑
k
=
1
K
e
z
k
−
e
z
i
⋅
e
z
j
(
∑
k
=
1
m
e
z
k
)
2
=
e
z
i
∑
k
=
1
m
e
z
k
−
e
z
i
∑
k
=
1
m
e
z
k
⋅
e
z
j
∑
k
=
1
m
e
z
k
=
y
^
i
−
y
^
i
2
=
y
^
i
(
1
−
y
^
i
)
\begin{aligned} \frac{\partial \hat y_i}{\partial z_j} &= \frac{e^{z_i}\cdot \sum_{k=1}^K e^{z_k} - e^{z_i} \cdot e^{z_j} }{(\sum_{k=1}^m e^{z_k})^2} \\ &= \frac{e^{z_i}}{\sum_{k=1}^m e^{z_k}} - \frac{e^{z_i}}{\sum_{k=1}^m e^{z_k}} \cdot \frac{e^{z_j}}{\sum_{k=1}^m e^{z_k}} \\ &= \hat y_i - \hat y_i^2 = \hat y_i(1 - \hat y_i) \end{aligned}
∂zj∂y^i=(∑k=1mezk)2ezi⋅∑k=1Kezk−ezi⋅ezj=∑k=1mezkezi−∑k=1mezkezi⋅∑k=1mezkezj=y^i−y^i2=y^i(1−y^i)
当
i
≠
j
i \neq j
i=j,
∂
y
^
i
∂
z
j
=
0
⋅
∑
k
=
1
K
e
z
k
−
e
z
i
⋅
e
z
j
(
∑
k
=
1
m
e
z
k
)
2
=
−
e
z
i
∑
k
=
1
m
e
z
k
⋅
e
z
j
∑
k
=
1
m
e
z
k
=
−
y
^
i
y
^
j
\begin{aligned} \frac{\partial \hat y_i}{\partial z_j} &= \frac{0 \cdot \sum_{k=1}^K e^{z_k} - e^{z_i} \cdot e^{z_j}}{(\sum_{k=1}^m e^{z_k})^2} \\ &= - \frac{e^{z_i}}{\sum_{k=1}^m e^{z_k}} \cdot \frac{e^{z_j}}{\sum_{k=1}^m e^{z_k}} \\ &= - \hat y_i \hat y_j \end{aligned}
∂zj∂y^i=(∑k=1mezk)20⋅∑k=1Kezk−ezi⋅ezj=−∑k=1mezkezi⋅∑k=1mezkezj=−y^iy^j
对cross-entropy求导
损失函数 L L L为:
L = − ∑ k y k log y ^ k L = -\sum_k y_k \log \hat y_k L=−k∑yklogy^k
其中 y k y_k yk是真实类别,相当于一个常数,接下来求 L L L对 z j z_j zj的导数
∂ L ∂ z j = ∂ − ( ∑ k y k log y ^ k ) z j = ∂ − ( ∑ k y k log y ^ k ) ∂ y ^ k ∂ y ^ k ∂ z j = − ∑ k y k 1 y ^ k ∂ y ^ k z j = ( − y k ⋅ y ^ k ( 1 − y ^ k ) 1 y ^ k ) k = j − ∑ k ≠ j y k 1 y ^ k ( − y ^ k y ^ j ) = − y j ( 1 − y ^ j ) − ∑ k ≠ j y k ( − y ^ j ) = − y j + y j y ^ j + ∑ k ≠ j y k ( y ^ j ) = − y j + ∑ k y k ( y ^ j ) = − y j + y ^ j = y ^ j − y j \begin{aligned} \frac{\partial L}{\partial z_j} &= \frac{\partial -(\sum_k y_k \log \hat y_k)}{z_j}\\ &= \frac{\partial -(\sum_k y_k \log \hat y_k)}{\partial \hat y_k} \frac{\partial \hat y_k}{\partial z_j} \\ &= -\sum_k y_k \frac{1}{\hat y_k} \frac{\partial \hat y_k}{z_j} \\ &= \left(-y_k \cdot \hat y_k(1 - \hat y_k) \frac{1}{\hat y_k} \right)_{k=j} - \sum_{k \neq j} y_k \frac{1}{\hat y_k} (-\hat y_k \hat y_j) \\ &= - y_j (1 -\hat y_j) - \sum_{k \neq j} y_k (-\hat y_j) \\ &= - y_j + y_j \hat y_j + \sum_{k \neq j} y_k (\hat y_j) \\ &= - y_j + \sum_{k} y_k (\hat y_j) \\ &= - y_j +\hat y_j \\ &= \hat y_j -y_j \end{aligned} ∂zj∂L=zj∂−(∑kyklogy^k)=∂y^k∂−(∑kyklogy^k)∂zj∂y^k=−k∑yky^k1zj∂y^k=(−yk⋅y^k(1−y^k)y^k1)k=j−k=j∑yky^k1(−y^ky^j)=−yj(1−y^j)−k=j∑yk(−y^j)=−yj+yjy^j+k=j∑yk(y^j)=−yj+k∑yk(y^j)=−yj+y^j=y^j−yj
这里用到了 ∑ k y k = 1 \sum_{k} y_k = 1 ∑kyk=1
可以看到,求导结果非常简单,如果不推导都不敢信。