【PyTorch 理论】交叉熵损失函数的理解


说起交叉熵损失函数「Cross Entropy Loss」,马上就能说出它的二分类公式:
L = − [ y l o g y ^ + ( 1 − y ) l o g ( 1 − y ^ ) ] L=−[ylog \hat{y}+(1−y)log (1−\hat{y})] L=[ylogy^+(1y)log(1y^)]
其中 y为真实标签, y ^ \hat{y} y^为预测概率。
但是它是怎么来的?为什么它能表征真实样本标签和预测概率之间的差值?

一、交叉熵损失函数的数学原理

我们知道,在二分类问题模型,真实样本的标签为 [0,1],分别表示负类和正类。模型的最后通常会经过一个 Sigmoid 函数,输出一个概率值,这个概率值反映了预测为正类的可能性:概率越大,可能性越大。

Sigmoid 函数的表达式和图形如下所示:

g ( s ) = 1 1 + e − s g(s) = \frac{1}{ 1+e^{-s}} g(s)=1+es1

在这里插入图片描述

其中 s 是模型上一层的输出,Sigmoid 函数有这样的特点:s = 0 时,g(s) = 0.5;s >> 0 时, g ≈ 1; s<<0时,g ≈ 0。显然,g(s) 将前一级的线性输出映射到 [0,1] 之间的数值概率上。这里的 g(s) 就是交叉熵公式中的模型预测输出 。

预测输出 y ^ \hat{y} y^即 Sigmoid 函数的输出表征了当前样本标签为 1 的概率:

y ^ = P ( y = 1 ∣ x ) \hat{y} =P(y=1|x) y^=P(y=1x)

很明显,当前样本标签为 0 的概率就可以表达成:

1 − y ^ = P ( y = 0 ∣ x ) 1-\hat{y} =P(y=0|x) 1y^=P(y=0x)

再从极大似然性的角度出发,把上面两种情况整合到一起:
P ( y ∣ x ) = y ^ y ∗ ( 1 − y ^ ) 1 − y P(y|x)=\hat{y} ^y*(1-\hat{y})^{1-y} P(yx)=y^y(1y^)1y

不懂极大似然估计也没关系。我们可以这么来看:

  1. 当真实样本标签 y = 0 时,上面式子第一项就为 1,概率等式转化为:
    P ( y = 0 ∣ x ) = 1 − y ^ P(y=0|x)=1-\hat{y} P(y=0x)=1y^
  2. 当真实样本标签 y = 1 时,上面式子第二项就为 1,概率等式转化为:
    P ( y = 1 ∣ x ) = y ^ P(y=1|x)=\hat{y} P(y=1x)=y^

两种情况下概率表达式跟之前的完全一致,只不过我们把两种情况整合在一起了,重点看一下整合之后的概率表达式。
我们希望的是概率 P(y|x) 越大越好。首先,我们对 P(y|x) 引入 log 函数,因为 log 运算并不会影响函数本身的单调性。则有:
在这里插入图片描述
我们希望 log P(y|x) 越大越好,反过来,只要 log P(y|x) 的负值 -log P(y|x) 越小就行了。那我们就可以引入损失函数,且令 Loss = -log P(y|x)即可。则得到损失函数为:

L = − [ y l o g y ^ + ( 1 − y ) l o g ( 1 − y ^ ) ] L=−[ylog \hat{y}+(1−y)log (1−\hat{y})] L=[ylogy^+(1y)log(1y^)]

以上,我们已经推导出了单个样本的损失函数,是如果是计算 N 个样本的总的损失函数,只要将 N 个 Loss 叠加起来就可以了:
在这里插入图片描述
这样,我们已经完整地实现了交叉熵损失函数的推导过程。

二. 交叉熵损失函数的直观理解

接下来,我们从图形的角度,分析交叉熵函数,加深大家的理解。
首先,还是写出单个样本的交叉熵损失函数:

L = − [ y l o g y ^ + ( 1 − y ) l o g ( 1 − y ^ ) ] L=−[ylog \hat{y}+(1−y)log (1−\hat{y})] L=[ylogy^+(1y)log(1y^)]

我们知道,当 y = 1 时:
L = − l o g y ^ L=−log \hat{y} L=logy^

这时候,L 与预测输出的关系如下图所示:
在这里插入图片描述
横坐标是预测输出,纵坐标是交叉熵损失函数 L。显然,预测输出越接近真实样本标签 1,损失函数 L 越小;预测输出越接近 0,L 越大。因此,函数的变化趋势完全符合实际需要的情况。
当 y = 0 时:
L = − l o g ( 1 − y ^ ) L=−log (1−\hat{y}) L=log(1y^)

这时候,L 与预测输出的关系如下图所示:
在这里插入图片描述
同样,预测输出越接近真实样本标签 0,损失函数 L 越小;预测函数越接近 1,L 越大。函数的变化趋势也完全符合实际需要的情况。

从上面两种图,可以帮助我们对交叉熵损失函数有更直观的理解。无论真实样本标签 y 是 0 还是 1,L 都表征了预测输出与 y 的差距。

另外,重点提一点的是,从图形中我们可以发现:预测输出与 y 差得越多,L 的值越大,也就是说对当前模型的 “ 惩罚 ” 越大,而且是非线性增大,是一种类似指数增长的级别。这是由 log 函数本身的特性所决定的。这样的好处是模型会倾向于让预测输出更接近真实样本标签 y。

  • 20
    点赞
  • 32
    收藏
    觉得还不错? 一键收藏
  • 4
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值