二分类和多分类交叉熵函数区别详解
写在前面
查了下百度,交叉熵,是度量两个分布间差异的概念。而在我们神经网络中,两个分布也就是y的真实值分布和预测值分布。当两个分布越接近时,其交叉熵值也就越小。
根据上面知识,也就转化为我们需要解决让预测值和真实值尽可能接近的问题,而这正与概率论数理统计中的最大似然分布一脉相承,进而目标转化为确定值的分布和求解最大似然估计问题。
二分类问题
表示分类任务中有两个类别,比如我们想判断一张图片是不是猫。也就是说,训练一个分类器,输入一张图片,用特征向量x表示,输出是不是猫用y=0或1表示,其中1表示是,0表示不是。
这样的问题,我们完全可以用0-1分布来进行表示:
y i y_i yi | 1 − y i 1-y_i 1−yi |
---|---|
y i ^ \hat{y_i} yi^ | 1 − y i ^ 1-\hat{y_i} 1−yi^ |
注:其中yi为真实值, y i ^ \hat{y_i} yi^为预测值,且 y i y_i yi的值为0或1
此时求解最大似然估计过程如下:
L
(
y
i
^
)
=
Π
i
=
1
n
y
i
^
y
i
(
1
−
y
i
^
)
1
−
y
i
L(\hat{y_i})=\Pi_{i=1}^{n}\hat{y_i}^{y_i}(1-\hat{y_i})^{1-y_i}
L(yi^)=Πi=1nyi^yi(1−yi^)1−yi
两边同时取对数
l
o
g
(
L
(
y
i
^
)
)
=
∑
i
=
1
n
(
y
i
l
o
g
(
y
i
^
)
+
(
1
−
y
i
)
l
o
g
(
1
−
y
i
^
)
)
log(L(\hat{y_i}))=\sum_{i=1}^{n}(y_ilog(\hat{y_i})+(1-y_i)log(1-\hat{y_i}))
log(L(yi^))=i=1∑n(yilog(yi^)+(1−yi)log(1−yi^))
最大似然估计要求数越大越好,而损失函数要求越小越好,因而损失函数在前面加上负号,因而也得到了二分类问题使用的交叉熵损失函数。
L
o
s
s
=
−
∑
i
=
1
n
(
y
i
l
o
g
(
y
i
^
)
+
(
1
−
y
i
)
l
o
g
(
1
−
y
i
^
)
)
Loss=-\sum_{i=1}^{n}(y_ilog(\hat{y_i})+(1-y_i)log(1-\hat{y_i}))
Loss=−i=1∑n(yilog(yi^)+(1−yi)log(1−yi^))
多分类问题
表示分类任务有多个类别,如对一堆水果分类,它们可能是橘子、苹果、梨等,每个样本有且只有一个标签。
这种情况与二分类类似,只是可能的情况增多了,可以描述为一个离散分布
y 1 y_{1} y1 | y 2 y_2 y2 | … | y k y_k yk |
---|---|---|---|
y 1 ^ \hat{y_1} y1^ | y 2 ^ \hat{y_2} y2^ | … | y k ^ \hat{y_k} yk^ |
注: y 1 、 y 2 . . . y k y_1、y_2...y_k y1、y2...yk为真实值,其中有且只有一个为1,其余为0。(采用one-hot编码)
此时求解最大似然函数过程如下:
L
(
y
i
^
)
=
Π
i
=
1
n
(
y
(
i
,
1
)
^
y
(
i
,
1
)
y
(
i
,
2
)
^
y
(
i
,
2
)
.
.
.
y
(
i
,
n
)
^
y
(
i
,
n
)
)
L(\hat{y_i})=\Pi_{i=1}^{n}(\hat{y_{(i,1)}}^{y_{(i,1)}}\hat{y_{(i,2)}}^{y_{(i,2)}}...\hat{y_{(i,n)}}^{y_{(i,n)}})
L(yi^)=Πi=1n(y(i,1)^y(i,1)y(i,2)^y(i,2)...y(i,n)^y(i,n))
因为真实值只有一个为1,其余为0,因而只有1项值非零,可化简为:
L
(
y
i
^
)
=
Π
i
=
1
n
y
(
i
,
m
)
^
y
(
i
,
m
)
L(\hat{y_i})=\Pi_{i=1}^{n}\hat{y_{(i,m)}}^{y_{(i,m)}}
L(yi^)=Πi=1ny(i,m)^y(i,m)
注:
y
(
i
,
m
)
^
\hat{y_{(i,m)}}
y(i,m)^表示含义为第i个样本,属于第m个类别(m值会随样本的变化动态改变)。
两边同时取对数:
l
o
g
(
L
(
y
i
^
)
)
=
∑
i
=
1
n
y
(
i
,
m
)
l
o
g
(
y
i
,
m
^
)
log(L(\hat{y_i}))=\sum_{i=1}^{n}y_{(i,m)}log(\hat{y_{i,m}})
log(L(yi^))=i=1∑ny(i,m)log(yi,m^)
与二元分类同理,此时多分类的交叉熵损失函数即为:
L
o
s
s
=
−
∑
i
=
1
n
y
(
i
,
m
)
l
o
g
(
y
i
,
m
^
)
Loss=-\sum_{i=1}^{n}y_{(i,m)}log(\hat{y_{i,m}})
Loss=−i=1∑ny(i,m)log(yi,m^)
参考文献
[1] https://www.bilibili.com/video/BV1a5411W7Dn?t=47
[2] https://juejin.cn/post/6844903630479294477