交叉熵损失导数推理

在深度学习网络训练中,交叉熵损失是一种经常使用的损失函数,这篇文章里我们来推导一下交叉熵损失关于网络输出z的导数,由于二分类是多分类的特殊情况,我们直接介绍多分类的推导过程。

一、Softmax交叉熵损失求导

基于softmax的多分类交叉熵公式为

L S C E = − ∑ j = 1 C y j log ⁡ ( p j ) L_{S C E}=-\sum_{j=1}^{C} y_{j} \log \left(p_{j}\right) LSCE=j=1Cyjlog(pj)

其中 C C C表示类别总数,包含背景类别, p j p_j pj通过 S o f t m a x ( z j ) Softmax(z_j) Softmax(zj)计算得到, z j z_j zj是网络的输出。 y y y是真实标签,通常由one-hot形式编码,单独一个样本的标签如下:

y j = { 1  if  j = c 0  otherwise  y_{j}=\left\{\begin{array}{ll} 1 & \text { if } j=c \\ 0 & \text { otherwise } \end{array}\right. yj={10 if j=c otherwise 
c c c表示这个样本属于 c c c类。
我们拿1个属于c类的样本来举例,网络输出为z,因为总共有 C C C类,所以网络有 C C C z z z值, { z 1 , z 2 , . . . z c , . . . z C − 1 , z C } \{z_1,z_2,...z_c,...z_{C-1},z_C\} {z1,z2,...zc,...zC1,zC},然后经过Softmax激活得到 C C C个和为1的概率值 p p p,该样本的真实标签 y y y只有 y c = 1 y_c=1 yc=1,其余都为0,每一类的损失是:-1x标签xlog(概率值),最后求和得到总损失。
在这里插入图片描述
可以知道, c c c类样本的标签编码中除了 y c y_c yc=1外,其他值 y j y_j yj都为0,所以这个样本对应的其他类的交叉熵都为0,总损失可以化简为:
L S C E = − log ⁡ ( p c ) p c = e z c ∑ j = 1 C e z j L_{S C E}=-\log \left(p_{c}\right)\\ p_c=\frac{e^{z_c} }{\sum_{j=1}^{C}{e^{z_j} }} LSCE=log(pc)pc=j=1Cezjezc
下面我们来计算一下损失 L S C E L_{SCE} LSCE对每个 z j z_j zj的导数。当 y j = 0 时 y_j=0时 yj=0,该类对应的损失为0,求导时无用,但是由于激活函数是Softmax,计算 p c p_c pc z j z_j zj被用到(分母),所以不管 y j y_j yj是否为0,对 z j z_j zj求导时,都需要考虑 c c c类对应的概率值 p c p_c pc

z j z_j zj求导需要用到链式求导法则,即
∂ L S C E ∂ z j = ∂ L S C E ∂ p c × ∂ p c ∂ z j = ∂ ( − l o g ( p c ) ) ∂ p c × ∂ ( e z c ∑ j = 1 C e z j ) ∂ z j = − 1 p c × ∂ e z c ∂ z j × ( ∑ j = 1 C e z j ) − e z c × e z j ( ∑ j = 1 C e z j ) 2 = − ∑ j = 1 C e z j e z c × ∂ e z c ∂ z j × ( ∑ j = 1 C e z j ) − e z c × e z j ( ∑ j = 1 C e z j ) 2 \begin{array}{lcl} \frac{\partial L_{SCE}}{\partial z_j} &=&\frac{\partial L_{SCE}}{\partial p_c}&\times &\frac{\partial p_c}{\partial z_j} \\[4mm] &=&\frac{\partial (-log(p_c))}{\partial p_c} &\times&\frac{\partial(\frac{e^{z_c}}{\sum_{j=1}^{C}e^{z_j}})}{\partial z_j}\\[4mm] &=&-\frac{1}{p_c} &\times &\frac{{\color{red}\frac{\partial e^{z_c}}{\partial z_j}}\times (\sum_{j=1}^{C}e^{z_j})-e^{z_c}\times e^{z_j}}{(\sum_{j=1}^{C}e^{z_j})^2} \\[4mm]&=&-\frac{\sum_{j=1}^{C}e^{z_j}}{e^{z_c}} &\times &\frac{{\color{red}\frac{\partial e^{z_c}}{\partial z_j}}\times (\sum_{j=1}^{C}e^{z_j})-e^{z_c}\times e^{z_j}}{(\sum_{j=1}^{C}e^{z_j})^2} \end{array} zjLSCE====pcLSCEpc(log(pc))pc1ezcj=1Cezj××××zjpczj(j=1Cezjezc)(j=1Cezj)2zjezc×(j=1Cezj)ezc×ezj(j=1Cezj)2zjezc×(j=1Cezj)ezc×ezj

j = c j=c j=c时,
∂ e z c ∂ z j = ∂ e z j ∂ z j = e z j \begin{array}{lcl} \color{red} \frac {\partial e^{z_c}} {\partial z_j} &=&\frac{\partial e^{z_j}} {\partial z_j}\\[4mm] &=&e^{z_j} \end{array} zjezc==zjezjezj
代入 ∂ L S C E ∂ z j \frac{\partial L_{SCE}}{\partial z_j} zjLSCE
∂ L S C E ∂ z j = − ∑ j = 1 C e z j e z c × ∂ e z c ∂ z j × ( ∑ j = 1 C e z j ) − e z c × e z j ( ∑ j = 1 C e z j ) 2 = − ∑ j = 1 C e z j e z j × e z j × ( ∑ j = 1 C e z j ) − e z j × e z j ( ∑ j = 1 C e z j ) 2 = − ( ∑ j = 1 C e z j ) − e z j ∑ j = 1 C e z j = − ( 1 − e z j ∑ j = 1 C e z j ) = p j − 1 \begin{array}{lcl} \frac{\partial L_{SCE}}{\partial z_j}&=&-\frac{\sum_{j=1}^{C}e^{z_j}}{e^{z_c}} &\times&\frac{ {\color {red} \frac{\partial e^{z_c}}{\partial z_j}} \times(\sum_{j=1}^{C}e^{z_j})-e^{z_c}\times e^{z_j}}{(\sum_{j=1}^{C}e^{z_j})^2} \\[4mm]&=&-\frac{\sum_{j=1}^{C}e^{z_j}}{e^{z_j}} &\times&\frac{{\color {red}e^{z_j}}\times(\sum_{j=1}^{C}e^{z_j})-e^{z_j}\times e^{z_j}}{(\sum_{j=1}^{C}e^{z_j})^2} \\[4mm]&=&-\frac{(\sum_{j=1}^{C}e^{z_j})-e^{z_j}}{\sum_{j=1}^{C}e^{z_j}}\\[4mm] &=&-(1-\frac{e^{z_j}}{\sum_{j=1}^{C}e^{z_j}})\\[4mm] &=&p_j-1 \end{array} zjLSCE=====ezcj=1Cezjezjj=1Cezjj=1Cezj(j=1Cezj)ezj(1j=1Cezjezj)pj1××(j=1Cezj)2zjezc×(j=1Cezj)ezc×ezj(j=1Cezj)2ezj×(j=1Cezj)ezj×ezj

j ≠ c j \neq c j=c
∂ e z c ∂ z j = 0 \begin{array}{lcl} \color{red} \frac{\partial e^{z_c}} {\partial z_j} &=&0 \end{array} zjezc=0
代入 ∂ L S C E ∂ z j \frac{\partial L_{SCE}}{\partial z_j} zjLSCE
∂ L S C E ∂ z j = − ∑ j = 1 C e z j e z c × ∂ e z c ∂ z j × ( ∑ j = 1 C e z j ) − e z c × e z j ( ∑ j = 1 C e z j ) 2 = − ∑ j = 1 C e z j e z c × 0 × ( ∑ j = 1 C e z j ) − e z c × e z j ( ∑ j = 1 C e z j ) 2 = − − e z j ∑ j = 1 C e z j = p j \begin{array}{lcl} \frac{\partial L_{SCE}}{\partial z_j}&=&-\frac{\sum_{j=1}^{C}e^{z_j}}{e^{z_c}} &\times &\frac{{\color{red}\frac{\partial e^{z_c}}{\partial z_j}}\times (\sum_{j=1}^{C}e^{z_j})-e^{z_c}\times e^{z_j}}{(\sum_{j=1}^{C}e^{z_j})^2} \\[4mm]&=&-\frac{\sum_{j=1}^{C}e^{z_j}}{e^{z_c}} &\times &\frac{{\color{red}0} \times (\sum_{j=1}^{C}e^{z_j})-e^{z_c}\times e^{z_j}}{(\sum_{j=1}^{C}e^{z_j})^2} \\[4mm]&=&-\frac{-e^{z_j}}{\sum_{j=1}^{C}e^{z_j}}\\[4mm] &=&p_j \end{array} zjLSCE====ezcj=1Cezjezcj=1Cezjj=1Cezjezjpj××(j=1Cezj)2zjezc×(j=1Cezj)ezc×ezj(j=1Cezj)20×(j=1Cezj)ezc×ezj

所以:
∂ L S C E ∂ z j = { p j − 1  if  j = c p j j ≠ c \frac{\partial L_{SCE}}{\partial z_{j}}=\left\{\begin{array}{ll} p_{j}-1 & \text { if } j=c \\ p_{j} & { j \ne c } \end{array}\right. zjLSCE={pj1pj if j=cj=c
在这里插入图片描述

二、Sigmoid交叉熵损失求导

sigmoid一般是用在二分类问题中,二分类时,网络只有一个输出值,经过sigmoid函数得到该样本是正样本的概率值。损失函数如下:
L = − y ∗ l o g p − ( 1 − y ) ∗ l o g ( 1 − p ) L=-y*logp-(1-y)*log(1-p) L=ylogp(1y)log(1p)
使用Sigmoid函数做多分类时,相当于把每一个类看成是独立的二分类问题,类之间不会相互影响。真实标签 y j y_j yj只表示j类的二分类情况。
基于sigmoid的多分类交叉熵公式如下:
L B C E = − ∑ j C log ⁡ ( p j ^ ) L_{B C E}=-\sum_{j}^{C} \log \left(\hat{p_{j}}\right) LBCE=jClog(pj^)

p j ^ = { p j  if  y j = 1 1 − p j  otherwise  \hat{p_{j}}=\left\{\begin{array}{ll} p_{j} & \text { if } y_{j}=1 \\ 1-p_{j} & \text { otherwise } \end{array}\right. pj^={pj1pj if yj=1 otherwise 
其中 p j p_j pj通过 σ ( z j ) \sigma\left(z_{j}\right) σ(zj)计算得到,即sigmoid函数,表达式如下:
p j = 1 1 + e − z j p_j=\frac{1}{1+e^{-z_j}} pj=1+ezj1
sigmoid函数的导数如下:
∂ p j ∂ z j = ∂ ( 1 ) ∂ z j × ( 1 + e − z j ) − 1 × ∂ ( 1 + e − z j ) ∂ z j ( 1 + e − z j ) 2 = − e − z j × ( − 1 ) ( 1 + e − z j ) 2 = 1 + e − z j − 1 ( 1 + e − z j ) 2 = 1 ( 1 + e − z j ) − 1 ( 1 + e − z j ) 2 = p j − p j 2 = p j ( 1 − p j ) \begin{array}{lcl} \frac{\partial p_j}{\partial z_j}&=&\frac{\frac{\partial(1)}{\partial z_j}\times (1+e^{-z_j})-1\times \frac{\partial(1+e^{-z_j})}{\partial z_j}}{(1+e^{-z_j})^2}\\[4mm] &=&\frac{-e^{-z_j}\times (-1)}{{(1+e^{-z_j})^2}}\\[4mm] &=&\frac{1+e^{-z_j}-1}{{(1+e^{-z_j})^2}}\\[4mm] &=&\frac{1}{{(1+e^{-z_j})}}-\frac{1}{{(1+e^{-z_j})^2}}\\[4mm] &=&p_j-p_j^{2}\\[4mm] &=&p_j(1-p_j) \end{array} zjpj======(1+ezj)2zj(1)×(1+ezj)1×zj(1+ezj)(1+ezj)2ezj×(1)(1+ezj)21+ezj1(1+ezj)1(1+ezj)21pjpj2pj(1pj)

我们拿1个属于c类的样本来举例,网络输出为z,因为总共有 C C C类,所以网络有 C C C z z z值, { z 1 , z 2 , . . . z c , . . . z C − 1 , z C } \{z_1,z_2,...z_c,...z_{C-1},z_C\} {z1,z2,...zc,...zC1,zC},然后经过sigmoid激活得到 C C C个独立的概率值 p p p,该样本的真实标签 y y y只有 y c = 1 y_c=1 yc=1,其余都为0。每一类都是一个单独的二分类问题,通过二分类交叉熵来计算损失,最后把所有类的损失相加。
在这里插入图片描述
现在我们计算损失 L B C E L_{BCE} LBCE关于网络输出 z z z的导数 ∂ L ∂ z j \frac{\partial L}{\partial z_j} zjL,这里需要用到链式法则,在计算Loss对 z j z_j zj的导数时,只需要考虑该类对应的 p j p_j pj即可,因为其他类的概率值跟 z j z_j zj没有关系。
∂ L B C E ∂ z j = ∂ L B C E ∂ p j × ∂ p j ∂ z j = ∂ ( − l o g ( p j ^ ) ) ∂ p j × p j × ( 1 − p j ) \begin{array}{c} \frac{\partial L_{BCE}}{\partial z_j}&=&\frac{\partial L_{BCE}}{\partial {p_j}}&\times& \frac{\partial {p_j}}{\partial z_j} \\[4mm]&=&{\color{red}\frac{\partial (-log(\hat{pj} ))}{\partial p_j}}& \times &p_j \times(1-pj) \end{array} zjLBCE==pjLBCEpj(log(pj^))××zjpjpj×(1pj)

y j = 1 y_j=1 yj=1时, p j ^ = p j \hat{p_j}=p_j pj^=pj:
∂ L B C E ∂ z j = ∂ L B C E ∂ p j × ∂ p j ∂ z j = ∂ ( − l o g ( p j ^ ) ) ∂ p j × p j × ( 1 − p j ) = ∂ ( − l o g ( p j ) ) ∂ p j × p j × ( 1 − p j ) = − 1 p j × p j × ( 1 − p j ) = p j − 1 \begin{array}{c} \frac{\partial L_{BCE}}{\partial z_j}&=&\frac{\partial L_{BCE}}{\partial {p_j}}&\times& \frac{\partial {p_j}}{\partial z_j} \\[4mm]&=&{\color{red}\frac{\partial (-log(\hat{pj} ))}{\partial p_j}}& \times &p_j \times(1-pj) \\[4mm]&=&{\color{red}\frac{\partial (-log({pj} ))}{\partial p_j}}& \times &p_j \times(1-pj) \\[4mm]&=&{-\color{red}\frac{1}{p_j} }& \times &p_j \times(1-pj) \\[4mm]&=&p_j-1 \end{array} zjLBCE=====pjLBCEpj(log(pj^))pj(log(pj))pj1pj1××××zjpjpj×(1pj)pj×(1pj)pj×(1pj)

y j = 0 y_j=0 yj=0时, p j ^ = 1 − p j \hat{p_j}=1-p_j pj^=1pj:
∂ L B C E ∂ z j = ∂ L B C E ∂ p j × ∂ p j ∂ z j = ∂ ( − l o g ( p j ^ ) ) ∂ p j × p j × ( 1 − p j ) = ∂ ( − l o g ( 1 − p j ) ) ∂ p j × p j × ( 1 − p j ) = − 1 1 − p j × − 1 × p j × ( 1 − p j ) = p j \begin{array}{c} \frac{\partial L_{BCE}}{\partial z_j}&=&\frac{\partial L_{BCE}}{\partial {p_j}}&\times& \frac{\partial {p_j}}{\partial z_j} \\[4mm]&=&{\color{red}\frac{\partial (-log(\hat{pj} ))}{\partial p_j}}& \times &p_j \times(1-pj) \\[4mm]&=&{\color{red}\frac{\partial (-log({1-pj} ))}{\partial p_j}}& \times &p_j \times(1-pj) \\[4mm]&=&{-\color{red}\frac{1}{1-p_j} \times -1} & \times &p_j \times(1-pj) \\[4mm]&=&p_j \end{array} zjLBCE=====pjLBCEpj(log(pj^))pj(log(1pj))1pj1×1pj××××zjpjpj×(1pj)pj×(1pj)pj×(1pj)
所以
∂ L B C E ∂ z j = { p j − 1  if  y j = 1 p j  otherwise  \frac{\partial L_{BCE}}{\partial z_{j}}=\left\{\begin{array}{ll} p_{j}-1 & \text { if } y_{j}=1 \\ p_{j} & \text { otherwise } \end{array}\right. zjLBCE={pj1pj if yj=1 otherwise 

在这里插入图片描述

三、总结

不管是使用sigmoid还是softmax作为最后的分类器,损失函数关于网络输出z的导数的形式是一样的。
∂ L o s s ∂ z j = { p j − 1  if  y j = 1 p j  otherwise  \frac{\partial Loss}{\partial z_{j}}=\left\{\begin{array}{ll} p_{j}-1 & \text { if } y_{j}=1 \\ p_{j} & \text { otherwise } \end{array}\right. zjLoss={pj1pj if yj=1 otherwise 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

GHZhao_GIS_RS

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值