最全的交叉熵损失函数(Pytorch)

引言

这里主要讲述pytorch中的几种交叉熵损失类,熵是用来描述一个系统的混乱程度,通过交叉熵我们就能够确定预测数据与真是数据之间的相近程度。交叉熵越小,表示数据越接近真实样本。公式为:
在这里插入图片描述

在pytorch中,损失可以通过函数或者类来计算,这里BCELoss、BCEWithLogitsLoss、NLLLoss、CrossEntropyLoss都是类,需要先进行类的定义,再调用函数方法,例如:

import torch
one = torch.nn.CrossEntropyLoss()(pre, label)

当然也可以直接使用函数方法,例如:

import torch.nn.functional as F
one = F.cross_entropy(pre,label)

BCELoss

BCELoss是一个二分类损失函数,全称:Binary Cross Entropy Loss,是交叉熵损失函数应用于二分类损失的特殊形式,一般配合sigmoid使用。公式为:
在这里插入图片描述
例如:

import torch
import torch.nn.functional as F
pre_a = torch.tensor([
    [0.8, 0.2]
],dtype=torch.float)
pre_b = torch.tensor([
    [0.6, 0.4]
],dtype=torch.float)
target = torch.tensor([
    [1, 0]
],dtype=torch.float)
a = F.binary_cross_entropy(pre_a, target)
b = F.binary_cross_entropy(pre_b, target)
print(a, b)
# 结果:tensor(0.2231) tensor(0.5108)

对于预测值a[0.8,0.2]、预测值b[0.6,0.4]和标签[1,0],其中预测值都是经过sigmoid得到的,他们的和为1,带入公式中计算得:
在这里插入图片描述
笔试计算和代码计算结果一致。可以看出预测值a与标签更加接近,同时损失函数也更小,符合预期结果。当然使用多个二分类损失也可以达到多分类的效果,例如是不是猫,是不是狗,是不是熊,我最先看到这样实现的好像是在YoloV4中吧!

BCEWithLogitsLoss

BCEWithLogitsLoss = sigmoid + BCELoss,使用BCEWithLogisLoss会自动帮助预测值进行sigmoid计算。

pre_a = torch.tensor([
    [1.8, 0.2]
],dtype=torch.float)
pre_a_sigmoid = F.sigmoid(pre_a)
a = F.binary_cross_entropy(pre_a_sigmoid,target)
b = F.binary_cross_entropy_with_logits(pre_a,target)
print(a,b) # tensor(0.4756) tensor(0.4756)

NLLLoss

NLLoss又称负对数似然损失函数,用于处理多分类问题,输入是对数化的概率值。公式为:
在这里插入图片描述

import torch
import torch.nn
 
a = torch.Tensor([[1,2,3]])
nll = nn.NLLLoss()
target1 = torch.Tensor([0]).long()
target2 = torch.Tensor([1]).long()
target3 = torch.Tensor([2]).long()
 
#测试
n1 = nll(a,target1)
#输出:tensor(-1.)
n2 = nll(a,target2)
#输出:tensor(-2.)
n3 = nll(a,target3)
#输出:tensor(-3.)

NLLLoss负对数似然损失函数实现的是将标签target位置的预测值取出求反

CrossEntropyLoss

CrossEntropyLoss是一种用于多分类的损失函数,输入是未经过softmax的tensor型值。在这里插入图片描述
CrossEntropyLoss就是将softmax、log、NLLLoss合为一体。公式为:
在这里插入图片描述
下面我将通过三种方法实现多分类交叉熵损失函数:

  1. 一步实现 CrossEntropyLoss
  2. 二步实现 log_softmax+nll_loss
  3. 三步实现 softmax+log+nll_loss
pre = np.array([
    [0.5, 0.1],
    [0.3, 0.8],
    [0.6, 0.1]
])
target = np.array([
    [1, 0],
    [0, 1],
    [1, 0]
])
pre = torch.from_numpy(pre)
target = torch.from_numpy(target)

label = torch.argmax(target, dim=1)
label = torch.LongTensor(label)

# 一步实现
one = torch.nn.CrossEntropyLoss()(pre, label)
print(one)
one = F.cross_entropy(pre,label)
print(one)

print('*' * 20)

# 两步实现
pre_two = torch.nn.LogSoftmax(dim=1)(pre)
print(pre_two)
two = torch.nn.NLLLoss()(pre_two, label)
print(two)

print('*' * 20)

# 三步实现
pre_three = torch.nn.Softmax(dim=1)(pre)
print(pre_three)
pre_log = torch.log(pre_three)
print(pre_log)
three = torch.nn.NLLLoss()(pre_log, label)
print(three)

结果都是tensor(0.4871, dtype=torch.float64)

总结

交叉熵损失函数广泛应用于图像分类、分割中,可以发现,不管是二分类还是多分类,其实计算损失函数都经历三个步骤,其中步骤之间可以合并:

  • 激活函数,通过sigmoid或者softmax将预测值缩放到[0,1]之间
  • log操作,对计算缩放求取log操作,进一步缩放至[-无穷,0]之间
  • 累积求和,根据函数定义,将标签和缩放后的预测值进行相乘求和

参考

https://blog.csdn.net/sdu_hao/article/details/103499223
https://blog.csdn.net/watermelon1123/article/details/91044856
https://blog.csdn.net/qq_16949707/article/details/79929951
https://zhuanlan.zhihu.com/p/159477597

  • 6
    点赞
  • 20
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
交叉损失函数pytorch中的一个常用函数,用于衡量分类任务中模型预测结果与真实标签之间的差异。在pytorch中,交叉损失函数的定义为nn.CrossEntropyLoss()。 该函数结合了nn.LogSoftmax()和nn.NLLLoss()两个函数的功能。其中,nn.LogSoftmax()用于对模型的输出进行log softmax操作,将其转化为概率分布;nn.NLLLoss()则用于计算负对数似然损失。因此,nn.CrossEntropyLoss()可以直接接收模型的输出和真实标签作为输入,并自动进行相应的处理,避免了手动进行softmax和计算负对数似然损失的麻烦。 在使用nn.CrossEntropyLoss()时,可以通过参数进行进一步的定制,比如设置权重、忽略特定的类别等。具体参数包括weight、size_average、ignore_index、reduce和reduction等。可以根据实际需要进行调整。 总结而言,交叉损失函数pytorch中是一个方便且常用的函数,用于衡量模型的预测结果与真实标签之间的差异,并可通过参数进行进一步的定制。<span class="em">1</span><span class="em">2</span><span class="em">3</span> #### 引用[.reference_title] - *1* *3* [【Pytorch交叉损失函数 CrossEntropyLoss() 详解](https://blog.csdn.net/weixin_44211968/article/details/123906631)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v93^chatsearchT3_2"}}] [.reference_item style="max-width: 50%"] - *2* [【pytorch交叉损失函数 nn.CrossEntropyLoss()](https://blog.csdn.net/weixin_37804469/article/details/125271074)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v93^chatsearchT3_2"}}] [.reference_item style="max-width: 50%"] [ .reference_list ]

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值