交叉熵损失(Cross Entropy Loss)计算过程

在机器学习中(特别是分类模型),模型训练时,通常都是使用交叉熵(Cross-Entropy)作为损失进行最小化:

C E ( p , q ) = − ∑ i = 1 C p i l o g ( q i ) CE(p,q)=- \sum_{i=1}^{C} p_i log(q_i) CE(p,q)=i=1Cpilog(qi)
其中 C C C代表类别数。 p i p_i pi为真实, q i q_i qi为预测。

我们以MNIST多分类为例,通常Label会编码为One-Hot,最后一层输出会使用Softmax函数进行概率化输出,如下表所示:

SampleTruePredicted
这里写图片描述[0, 1, 0, 0, 0, 0, 0, 0, 0, 0][0.1, 0.6, 0.3, 0, 0, 0, 0, 0, 0, 0]
这里写图片描述[0, 0, 0, 0, 1, 0, 0, 0, 0, 0][0, 0.3, 0.2, 0, 0.5, 0, 0, 0, 0, 0]
这里写图片描述[0, 0, 0, 0, 0, 1, 0, 0, 0, 0][0.6, 0.3, 0, 0, 0, 0.1, 0, 0, 0, 0]

对于第一个样本,交叉熵损失为:
− l n ( 0.6 ) ≈ 0.51 -ln(0.6) \approx 0.51 ln(0.6)0.51

对于第二个样本,交叉熵损失为:
− l n ( 0.5 ) ≈ 0.69 -ln(0.5) \approx 0.69 ln(0.5)0.69

对于第三个样本,交叉熵损失为:
− l n ( 0.1 ) ≈ 2.30 -ln(0.1) \approx 2.30 ln(0.1)2.30

平均交叉熵损失为:
− ( l n ( 0.6 ) + l n ( 0.5 ) + l n ( 0.1 ) ) 3 ≈ 1.17 -\frac{(ln(0.6)+ln(0.5)+ln(0.1))}{3} \approx 1.17 3(ln(0.6)+ln(0.5)+ln(0.1))1.17

从上面的计算可以知道,预测越准,损失越小。

Scikit-learn中提供了交叉熵损失的计算方法:

from sklearn.metrics import log_loss

true = ['1', '4', '5']
pred=[[0.1, 0.6, 0.3, 0, 0, 0, 0, 0, 0, 0],
      [0, 0.3, 0.2, 0, 0.5, 0, 0, 0, 0, 0],
      [0.6, 0.3, 0, 0, 0, 0.1, 0, 0, 0, 0]]
labels=['0','1','2','3','4','5','6','7','8','9']

log_loss(true, pred, labels)

Out:
1.1688526324400008

为什么训练时采取交叉熵损失,而不用均方误差(Mean Squared Error, MSE)呢?

Why You Should Use Cross-Entropy Error Instead Of Classification Error Or Mean Squared Error For Neural Network Classifier Training -> 翻译版

  • 26
    点赞
  • 108
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 7
    评论
评论 7
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

手撕机

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值