在深度学习网络训练中,交叉熵损失是一种经常使用的损失函数,这篇文章里我们来推导一下交叉熵损失关于网络输出z的导数,由于二分类是多分类的特殊情况,我们直接介绍多分类的推导过程。
一、Softmax交叉熵损失求导
基于softmax的多分类交叉熵公式为
L S C E = − ∑ j = 1 C y j log ( p j ) L_{S C E}=-\sum_{j=1}^{C} y_{j} \log \left(p_{j}\right) LSCE=−j=1∑Cyjlog(pj)
其中 C C C表示类别总数,包含背景类别, p j p_j pj通过 S o f t m a x ( z j ) Softmax(z_j) Softmax(zj)计算得到, z j z_j zj是网络的输出。 y y y是真实标签,通常由one-hot形式编码,单独一个样本的标签如下:
y
j
=
{
1
if
j
=
c
0
otherwise
y_{j}=\left\{\begin{array}{ll} 1 & \text { if } j=c \\ 0 & \text { otherwise } \end{array}\right.
yj={10 if j=c otherwise
c
c
c表示这个样本属于
c
c
c类。
我们拿1个属于c类的样本来举例,网络输出为z,因为总共有
C
C
C类,所以网络有
C
C
C个
z
z
z值,
{
z
1
,
z
2
,
.
.
.
z
c
,
.
.
.
z
C
−
1
,
z
C
}
\{z_1,z_2,...z_c,...z_{C-1},z_C\}
{z1,z2,...zc,...zC−1,zC},然后经过Softmax激活得到
C
C
C个和为1的概率值
p
p
p,该样本的真实标签
y
y
y只有
y
c
=
1
y_c=1
yc=1,其余都为0,每一类的损失是:-1x标签xlog(概率值),最后求和得到总损失。
可以知道,
c
c
c类样本的标签编码中除了
y
c
y_c
yc=1外,其他值
y
j
y_j
yj都为0,所以这个样本对应的其他类的交叉熵都为0,总损失可以化简为:
L
S
C
E
=
−
log
(
p
c
)
p
c
=
e
z
c
∑
j
=
1
C
e
z
j
L_{S C E}=-\log \left(p_{c}\right)\\ p_c=\frac{e^{z_c} }{\sum_{j=1}^{C}{e^{z_j} }}
LSCE=−log(pc)pc=∑j=1Cezjezc
下面我们来计算一下损失
L
S
C
E
L_{SCE}
LSCE对每个
z
j
z_j
zj的导数。当
y
j
=
0
时
y_j=0时
yj=0时,该类对应的损失为0,求导时无用,但是由于激活函数是Softmax,计算
p
c
p_c
pc时
z
j
z_j
zj被用到(分母),所以不管
y
j
y_j
yj是否为0,对
z
j
z_j
zj求导时,都需要考虑
c
c
c类对应的概率值
p
c
p_c
pc。
对
z
j
z_j
zj求导需要用到链式求导法则,即
∂
L
S
C
E
∂
z
j
=
∂
L
S
C
E
∂
p
c
×
∂
p
c
∂
z
j
=
∂
(
−
l
o
g
(
p
c
)
)
∂
p
c
×
∂
(
e
z
c
∑
j
=
1
C
e
z
j
)
∂
z
j
=
−
1
p
c
×
∂
e
z
c
∂
z
j
×
(
∑
j
=
1
C
e
z
j
)
−
e
z
c
×
e
z
j
(
∑
j
=
1
C
e
z
j
)
2
=
−
∑
j
=
1
C
e
z
j
e
z
c
×
∂
e
z
c
∂
z
j
×
(
∑
j
=
1
C
e
z
j
)
−
e
z
c
×
e
z
j
(
∑
j
=
1
C
e
z
j
)
2
\begin{array}{lcl} \frac{\partial L_{SCE}}{\partial z_j} &=&\frac{\partial L_{SCE}}{\partial p_c}&\times &\frac{\partial p_c}{\partial z_j} \\[4mm] &=&\frac{\partial (-log(p_c))}{\partial p_c} &\times&\frac{\partial(\frac{e^{z_c}}{\sum_{j=1}^{C}e^{z_j}})}{\partial z_j}\\[4mm] &=&-\frac{1}{p_c} &\times &\frac{{\color{red}\frac{\partial e^{z_c}}{\partial z_j}}\times (\sum_{j=1}^{C}e^{z_j})-e^{z_c}\times e^{z_j}}{(\sum_{j=1}^{C}e^{z_j})^2} \\[4mm]&=&-\frac{\sum_{j=1}^{C}e^{z_j}}{e^{z_c}} &\times &\frac{{\color{red}\frac{\partial e^{z_c}}{\partial z_j}}\times (\sum_{j=1}^{C}e^{z_j})-e^{z_c}\times e^{z_j}}{(\sum_{j=1}^{C}e^{z_j})^2} \end{array}
∂zj∂LSCE====∂pc∂LSCE∂pc∂(−log(pc))−pc1−ezc∑j=1Cezj××××∂zj∂pc∂zj∂(∑j=1Cezjezc)(∑j=1Cezj)2∂zj∂ezc×(∑j=1Cezj)−ezc×ezj(∑j=1Cezj)2∂zj∂ezc×(∑j=1Cezj)−ezc×ezj
当
j
=
c
j=c
j=c时,
∂
e
z
c
∂
z
j
=
∂
e
z
j
∂
z
j
=
e
z
j
\begin{array}{lcl} \color{red} \frac {\partial e^{z_c}} {\partial z_j} &=&\frac{\partial e^{z_j}} {\partial z_j}\\[4mm] &=&e^{z_j} \end{array}
∂zj∂ezc==∂zj∂ezjezj
代入
∂
L
S
C
E
∂
z
j
\frac{\partial L_{SCE}}{\partial z_j}
∂zj∂LSCE得
∂
L
S
C
E
∂
z
j
=
−
∑
j
=
1
C
e
z
j
e
z
c
×
∂
e
z
c
∂
z
j
×
(
∑
j
=
1
C
e
z
j
)
−
e
z
c
×
e
z
j
(
∑
j
=
1
C
e
z
j
)
2
=
−
∑
j
=
1
C
e
z
j
e
z
j
×
e
z
j
×
(
∑
j
=
1
C
e
z
j
)
−
e
z
j
×
e
z
j
(
∑
j
=
1
C
e
z
j
)
2
=
−
(
∑
j
=
1
C
e
z
j
)
−
e
z
j
∑
j
=
1
C
e
z
j
=
−
(
1
−
e
z
j
∑
j
=
1
C
e
z
j
)
=
p
j
−
1
\begin{array}{lcl} \frac{\partial L_{SCE}}{\partial z_j}&=&-\frac{\sum_{j=1}^{C}e^{z_j}}{e^{z_c}} &\times&\frac{ {\color {red} \frac{\partial e^{z_c}}{\partial z_j}} \times(\sum_{j=1}^{C}e^{z_j})-e^{z_c}\times e^{z_j}}{(\sum_{j=1}^{C}e^{z_j})^2} \\[4mm]&=&-\frac{\sum_{j=1}^{C}e^{z_j}}{e^{z_j}} &\times&\frac{{\color {red}e^{z_j}}\times(\sum_{j=1}^{C}e^{z_j})-e^{z_j}\times e^{z_j}}{(\sum_{j=1}^{C}e^{z_j})^2} \\[4mm]&=&-\frac{(\sum_{j=1}^{C}e^{z_j})-e^{z_j}}{\sum_{j=1}^{C}e^{z_j}}\\[4mm] &=&-(1-\frac{e^{z_j}}{\sum_{j=1}^{C}e^{z_j}})\\[4mm] &=&p_j-1 \end{array}
∂zj∂LSCE=====−ezc∑j=1Cezj−ezj∑j=1Cezj−∑j=1Cezj(∑j=1Cezj)−ezj−(1−∑j=1Cezjezj)pj−1××(∑j=1Cezj)2∂zj∂ezc×(∑j=1Cezj)−ezc×ezj(∑j=1Cezj)2ezj×(∑j=1Cezj)−ezj×ezj
当
j
≠
c
j \neq c
j=c时
∂
e
z
c
∂
z
j
=
0
\begin{array}{lcl} \color{red} \frac{\partial e^{z_c}} {\partial z_j} &=&0 \end{array}
∂zj∂ezc=0
代入
∂
L
S
C
E
∂
z
j
\frac{\partial L_{SCE}}{\partial z_j}
∂zj∂LSCE,
∂
L
S
C
E
∂
z
j
=
−
∑
j
=
1
C
e
z
j
e
z
c
×
∂
e
z
c
∂
z
j
×
(
∑
j
=
1
C
e
z
j
)
−
e
z
c
×
e
z
j
(
∑
j
=
1
C
e
z
j
)
2
=
−
∑
j
=
1
C
e
z
j
e
z
c
×
0
×
(
∑
j
=
1
C
e
z
j
)
−
e
z
c
×
e
z
j
(
∑
j
=
1
C
e
z
j
)
2
=
−
−
e
z
j
∑
j
=
1
C
e
z
j
=
p
j
\begin{array}{lcl} \frac{\partial L_{SCE}}{\partial z_j}&=&-\frac{\sum_{j=1}^{C}e^{z_j}}{e^{z_c}} &\times &\frac{{\color{red}\frac{\partial e^{z_c}}{\partial z_j}}\times (\sum_{j=1}^{C}e^{z_j})-e^{z_c}\times e^{z_j}}{(\sum_{j=1}^{C}e^{z_j})^2} \\[4mm]&=&-\frac{\sum_{j=1}^{C}e^{z_j}}{e^{z_c}} &\times &\frac{{\color{red}0} \times (\sum_{j=1}^{C}e^{z_j})-e^{z_c}\times e^{z_j}}{(\sum_{j=1}^{C}e^{z_j})^2} \\[4mm]&=&-\frac{-e^{z_j}}{\sum_{j=1}^{C}e^{z_j}}\\[4mm] &=&p_j \end{array}
∂zj∂LSCE====−ezc∑j=1Cezj−ezc∑j=1Cezj−∑j=1Cezj−ezjpj××(∑j=1Cezj)2∂zj∂ezc×(∑j=1Cezj)−ezc×ezj(∑j=1Cezj)20×(∑j=1Cezj)−ezc×ezj
所以:
∂
L
S
C
E
∂
z
j
=
{
p
j
−
1
if
j
=
c
p
j
j
≠
c
\frac{\partial L_{SCE}}{\partial z_{j}}=\left\{\begin{array}{ll} p_{j}-1 & \text { if } j=c \\ p_{j} & { j \ne c } \end{array}\right.
∂zj∂LSCE={pj−1pj if j=cj=c
二、Sigmoid交叉熵损失求导
sigmoid一般是用在二分类问题中,二分类时,网络只有一个输出值,经过sigmoid函数得到该样本是正样本的概率值。损失函数如下:
L
=
−
y
∗
l
o
g
p
−
(
1
−
y
)
∗
l
o
g
(
1
−
p
)
L=-y*logp-(1-y)*log(1-p)
L=−y∗logp−(1−y)∗log(1−p)
使用Sigmoid函数做多分类时,相当于把每一个类看成是独立的二分类问题,类之间不会相互影响。真实标签
y
j
y_j
yj只表示j类的二分类情况。
基于sigmoid的多分类交叉熵公式如下:
L
B
C
E
=
−
∑
j
C
log
(
p
j
^
)
L_{B C E}=-\sum_{j}^{C} \log \left(\hat{p_{j}}\right)
LBCE=−j∑Clog(pj^)
p
j
^
=
{
p
j
if
y
j
=
1
1
−
p
j
otherwise
\hat{p_{j}}=\left\{\begin{array}{ll} p_{j} & \text { if } y_{j}=1 \\ 1-p_{j} & \text { otherwise } \end{array}\right.
pj^={pj1−pj if yj=1 otherwise
其中
p
j
p_j
pj通过
σ
(
z
j
)
\sigma\left(z_{j}\right)
σ(zj)计算得到,即sigmoid函数,表达式如下:
p
j
=
1
1
+
e
−
z
j
p_j=\frac{1}{1+e^{-z_j}}
pj=1+e−zj1
sigmoid函数的导数如下:
∂
p
j
∂
z
j
=
∂
(
1
)
∂
z
j
×
(
1
+
e
−
z
j
)
−
1
×
∂
(
1
+
e
−
z
j
)
∂
z
j
(
1
+
e
−
z
j
)
2
=
−
e
−
z
j
×
(
−
1
)
(
1
+
e
−
z
j
)
2
=
1
+
e
−
z
j
−
1
(
1
+
e
−
z
j
)
2
=
1
(
1
+
e
−
z
j
)
−
1
(
1
+
e
−
z
j
)
2
=
p
j
−
p
j
2
=
p
j
(
1
−
p
j
)
\begin{array}{lcl} \frac{\partial p_j}{\partial z_j}&=&\frac{\frac{\partial(1)}{\partial z_j}\times (1+e^{-z_j})-1\times \frac{\partial(1+e^{-z_j})}{\partial z_j}}{(1+e^{-z_j})^2}\\[4mm] &=&\frac{-e^{-z_j}\times (-1)}{{(1+e^{-z_j})^2}}\\[4mm] &=&\frac{1+e^{-z_j}-1}{{(1+e^{-z_j})^2}}\\[4mm] &=&\frac{1}{{(1+e^{-z_j})}}-\frac{1}{{(1+e^{-z_j})^2}}\\[4mm] &=&p_j-p_j^{2}\\[4mm] &=&p_j(1-p_j) \end{array}
∂zj∂pj======(1+e−zj)2∂zj∂(1)×(1+e−zj)−1×∂zj∂(1+e−zj)(1+e−zj)2−e−zj×(−1)(1+e−zj)21+e−zj−1(1+e−zj)1−(1+e−zj)21pj−pj2pj(1−pj)
我们拿1个属于c类的样本来举例,网络输出为z,因为总共有
C
C
C类,所以网络有
C
C
C个
z
z
z值,
{
z
1
,
z
2
,
.
.
.
z
c
,
.
.
.
z
C
−
1
,
z
C
}
\{z_1,z_2,...z_c,...z_{C-1},z_C\}
{z1,z2,...zc,...zC−1,zC},然后经过sigmoid激活得到
C
C
C个独立的概率值
p
p
p,该样本的真实标签
y
y
y只有
y
c
=
1
y_c=1
yc=1,其余都为0。每一类都是一个单独的二分类问题,通过二分类交叉熵来计算损失,最后把所有类的损失相加。
现在我们计算损失
L
B
C
E
L_{BCE}
LBCE关于网络输出
z
z
z的导数
∂
L
∂
z
j
\frac{\partial L}{\partial z_j}
∂zj∂L,这里需要用到链式法则,在计算Loss对
z
j
z_j
zj的导数时,只需要考虑该类对应的
p
j
p_j
pj即可,因为其他类的概率值跟
z
j
z_j
zj没有关系。
∂
L
B
C
E
∂
z
j
=
∂
L
B
C
E
∂
p
j
×
∂
p
j
∂
z
j
=
∂
(
−
l
o
g
(
p
j
^
)
)
∂
p
j
×
p
j
×
(
1
−
p
j
)
\begin{array}{c} \frac{\partial L_{BCE}}{\partial z_j}&=&\frac{\partial L_{BCE}}{\partial {p_j}}&\times& \frac{\partial {p_j}}{\partial z_j} \\[4mm]&=&{\color{red}\frac{\partial (-log(\hat{pj} ))}{\partial p_j}}& \times &p_j \times(1-pj) \end{array}
∂zj∂LBCE==∂pj∂LBCE∂pj∂(−log(pj^))××∂zj∂pjpj×(1−pj)
当
y
j
=
1
y_j=1
yj=1时,
p
j
^
=
p
j
\hat{p_j}=p_j
pj^=pj:
∂
L
B
C
E
∂
z
j
=
∂
L
B
C
E
∂
p
j
×
∂
p
j
∂
z
j
=
∂
(
−
l
o
g
(
p
j
^
)
)
∂
p
j
×
p
j
×
(
1
−
p
j
)
=
∂
(
−
l
o
g
(
p
j
)
)
∂
p
j
×
p
j
×
(
1
−
p
j
)
=
−
1
p
j
×
p
j
×
(
1
−
p
j
)
=
p
j
−
1
\begin{array}{c} \frac{\partial L_{BCE}}{\partial z_j}&=&\frac{\partial L_{BCE}}{\partial {p_j}}&\times& \frac{\partial {p_j}}{\partial z_j} \\[4mm]&=&{\color{red}\frac{\partial (-log(\hat{pj} ))}{\partial p_j}}& \times &p_j \times(1-pj) \\[4mm]&=&{\color{red}\frac{\partial (-log({pj} ))}{\partial p_j}}& \times &p_j \times(1-pj) \\[4mm]&=&{-\color{red}\frac{1}{p_j} }& \times &p_j \times(1-pj) \\[4mm]&=&p_j-1 \end{array}
∂zj∂LBCE=====∂pj∂LBCE∂pj∂(−log(pj^))∂pj∂(−log(pj))−pj1pj−1××××∂zj∂pjpj×(1−pj)pj×(1−pj)pj×(1−pj)
当
y
j
=
0
y_j=0
yj=0时,
p
j
^
=
1
−
p
j
\hat{p_j}=1-p_j
pj^=1−pj:
∂
L
B
C
E
∂
z
j
=
∂
L
B
C
E
∂
p
j
×
∂
p
j
∂
z
j
=
∂
(
−
l
o
g
(
p
j
^
)
)
∂
p
j
×
p
j
×
(
1
−
p
j
)
=
∂
(
−
l
o
g
(
1
−
p
j
)
)
∂
p
j
×
p
j
×
(
1
−
p
j
)
=
−
1
1
−
p
j
×
−
1
×
p
j
×
(
1
−
p
j
)
=
p
j
\begin{array}{c} \frac{\partial L_{BCE}}{\partial z_j}&=&\frac{\partial L_{BCE}}{\partial {p_j}}&\times& \frac{\partial {p_j}}{\partial z_j} \\[4mm]&=&{\color{red}\frac{\partial (-log(\hat{pj} ))}{\partial p_j}}& \times &p_j \times(1-pj) \\[4mm]&=&{\color{red}\frac{\partial (-log({1-pj} ))}{\partial p_j}}& \times &p_j \times(1-pj) \\[4mm]&=&{-\color{red}\frac{1}{1-p_j} \times -1} & \times &p_j \times(1-pj) \\[4mm]&=&p_j \end{array}
∂zj∂LBCE=====∂pj∂LBCE∂pj∂(−log(pj^))∂pj∂(−log(1−pj))−1−pj1×−1pj××××∂zj∂pjpj×(1−pj)pj×(1−pj)pj×(1−pj)
所以
∂
L
B
C
E
∂
z
j
=
{
p
j
−
1
if
y
j
=
1
p
j
otherwise
\frac{\partial L_{BCE}}{\partial z_{j}}=\left\{\begin{array}{ll} p_{j}-1 & \text { if } y_{j}=1 \\ p_{j} & \text { otherwise } \end{array}\right.
∂zj∂LBCE={pj−1pj if yj=1 otherwise
三、总结
不管是使用sigmoid还是softmax作为最后的分类器,损失函数关于网络输出z的导数的形式是一样的。
∂
L
o
s
s
∂
z
j
=
{
p
j
−
1
if
y
j
=
1
p
j
otherwise
\frac{\partial Loss}{\partial z_{j}}=\left\{\begin{array}{ll} p_{j}-1 & \text { if } y_{j}=1 \\ p_{j} & \text { otherwise } \end{array}\right.
∂zj∂Loss={pj−1pj if yj=1 otherwise