【多标签单分类】交叉熵计算

对https://blog.csdn.net/zfhsfdhdfajhsr/article/details/124689632这篇笔记进行补充,其中这个博主给的例子有错误。具体错误自己跑了代码就知道了,不细说,看看我写的代码就知道了:

target = torch.tensor([0,2,1]) # shape(3)
input_ = torch.tensor([[0.13, -0.18, 0.87],
                       [0.25, -0.04, 0.32],
                       [0.24, -0.54, 0.53]]) # shape(3,3)
loss_item = torch.nn.CrossEntropyLoss()
loss = loss_item(input_.view(-1,3), target.view(-1))

这是一个单标签多分类的问题,这个例子一共有3个类,要从三个类中选择一个类,将预测值和真实值输入进交叉熵损失函数,首先会对输入值进行softmax,归一化到0-1之间,然后再计算。下面是我的手动推理过程。

# 第一个样本的计算
x_01 = 0.13
x_02 = -0.18
x_03 = 0.87

loss_1 = -log(e^(0.13) / (e^(0.13) + e^(-0.18) + e^(0.87)))-log(0.593 / (0.593 + 0.832 + 2.392))-log(0.593 / 3.817)-log(0.1553)1.8602
       
# 第二个样本的计算
x_11 = 0.25
x_12 = -0.04
x_13 = 0.32

loss_2 = -log(e^(-0.04) / (e^(0.25) + e^(-0.04) + e^(0.32)))-log(0.961 / (1.284 + 0.961 + 1.379))-log(0.961 / 3.624)-log(0.2656)1.3289
       
# 第三个样本的计算
x_21 = 0.24
x_22 = -0.54
x_23 = 0.53

loss_3 = -log(e^(0.53) / (e^(0.24) + e^(-0.54) + e^(0.53)))-log(1.701 / (1.272 + 0.582 + 1.701))-log(1.701 / 3.555)-log(0.4787)0.7346

# 最后求的他们平均值
loss = (loss_1 + loss_2 + loss_3) / N
     = (1.8602 + 1.3289 + 0.7346) / 31.3079

最后总结得到的数学公式为:

loss = -(1/N) * Σ(log(e^(x_ij) / Σ(e^(x_ik))), i=1 to N)
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值