二分类和多分类交叉熵函数区别详解

二分类和多分类交叉熵函数区别详解

写在前面

查了下百度,交叉熵,是度量两个分布间差异的概念。而在我们神经网络中,两个分布也就是y的真实值分布和预测值分布。当两个分布越接近时,其交叉熵值也就越小。

根据上面知识,也就转化为我们需要解决让预测值和真实值尽可能接近的问题,而这正与概率论数理统计中的最大似然分布一脉相承,进而目标转化为确定值的分布和求解最大似然估计问题。

二分类问题

表示分类任务中有两个类别,比如我们想判断一张图片是不是猫。也就是说,训练一个分类器,输入一张图片,用特征向量x表示,输出是不是猫用y=0或1表示,其中1表示是,0表示不是。

这样的问题,我们完全可以用0-1分布来进行表示:

y i y_i yi 1 − y i 1-y_i 1yi
y i ^ \hat{y_i} yi^ 1 − y i ^ 1-\hat{y_i} 1yi^

注:其中yi为真实值, y i ^ \hat{y_i} yi^为预测值,且 y i y_i yi的值为0或1

此时求解最大似然估计过程如下:
L ( y i ^ ) = Π i = 1 n y i ^ y i ( 1 − y i ^ ) 1 − y i L(\hat{y_i})=\Pi_{i=1}^{n}\hat{y_i}^{y_i}(1-\hat{y_i})^{1-y_i} L(yi^)=Πi=1nyi^yi(1yi^)1yi
两边同时取对数
l o g ( L ( y i ^ ) ) = ∑ i = 1 n ( y i l o g ( y i ^ ) + ( 1 − y i ) l o g ( 1 − y i ^ ) ) log(L(\hat{y_i}))=\sum_{i=1}^{n}(y_ilog(\hat{y_i})+(1-y_i)log(1-\hat{y_i})) log(L(yi^))=i=1n(yilog(yi^)+(1yi)log(1yi^))
最大似然估计要求数越大越好,而损失函数要求越小越好,因而损失函数在前面加上负号,因而也得到了二分类问题使用的交叉熵损失函数
L o s s = − ∑ i = 1 n ( y i l o g ( y i ^ ) + ( 1 − y i ) l o g ( 1 − y i ^ ) ) Loss=-\sum_{i=1}^{n}(y_ilog(\hat{y_i})+(1-y_i)log(1-\hat{y_i})) Loss=i=1n(yilog(yi^)+(1yi)log(1yi^))

多分类问题

表示分类任务有多个类别,如对一堆水果分类,它们可能是橘子、苹果、梨等,每个样本有且只有一个标签。

这种情况与二分类类似,只是可能的情况增多了,可以描述为一个离散分布

y 1 y_{1} y1 y 2 y_2 y2 y k y_k yk
y 1 ^ \hat{y_1} y1^ y 2 ^ \hat{y_2} y2^ y k ^ \hat{y_k} yk^

注: y 1 、 y 2 . . . y k y_1、y_2...y_k y1y2...yk为真实值,其中有且只有一个为1,其余为0。(采用one-hot编码)

此时求解最大似然函数过程如下:
L ( y i ^ ) = Π i = 1 n ( y ( i , 1 ) ^ y ( i , 1 ) y ( i , 2 ) ^ y ( i , 2 ) . . . y ( i , n ) ^ y ( i , n ) ) L(\hat{y_i})=\Pi_{i=1}^{n}(\hat{y_{(i,1)}}^{y_{(i,1)}}\hat{y_{(i,2)}}^{y_{(i,2)}}...\hat{y_{(i,n)}}^{y_{(i,n)}}) L(yi^)=Πi=1n(y(i,1)^y(i,1)y(i,2)^y(i,2)...y(i,n)^y(i,n))
因为真实值只有一个为1,其余为0,因而只有1项值非零,可化简为:
L ( y i ^ ) = Π i = 1 n y ( i , m ) ^ y ( i , m ) L(\hat{y_i})=\Pi_{i=1}^{n}\hat{y_{(i,m)}}^{y_{(i,m)}} L(yi^)=Πi=1ny(i,m)^y(i,m)
注: y ( i , m ) ^ \hat{y_{(i,m)}} y(i,m)^表示含义为第i个样本,属于第m个类别(m值会随样本的变化动态改变)

两边同时取对数:
l o g ( L ( y i ^ ) ) = ∑ i = 1 n y ( i , m ) l o g ( y i , m ^ ) log(L(\hat{y_i}))=\sum_{i=1}^{n}y_{(i,m)}log(\hat{y_{i,m}}) log(L(yi^))=i=1ny(i,m)log(yi,m^)
与二元分类同理,此时多分类的交叉熵损失函数即为:
L o s s = − ∑ i = 1 n y ( i , m ) l o g ( y i , m ^ ) Loss=-\sum_{i=1}^{n}y_{(i,m)}log(\hat{y_{i,m}}) Loss=i=1ny(i,m)log(yi,m^)

参考文献

[1] https://www.bilibili.com/video/BV1a5411W7Dn?t=47
[2] https://juejin.cn/post/6844903630479294477

  • 8
    点赞
  • 23
    收藏
    觉得还不错? 一键收藏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值