【Pytorch基础】torch.nn.CrossEntropyLoss损失函数介绍

1 交叉熵的定义

  交叉熵主要是用来判定实际的输出与期望的输出的接近程度,为什么这么说呢,举个例子:在做分类的训练的时候,如果一个样本属于第K类,那么这个类别所对应的输出节点的输出值应该为1,而其他节点的输出都为0,即[0,0,1,0,….0,0],这个数组也就是样本的Label,是神经网络最期望的输出结果。也就是说用它来衡量网络的输出与标签的差异,利用这种差异经过反向传播去更新网络参数。参考文献【1】

2 交叉熵的数学原理

Pytorch常用的交叉熵损失函数CrossEntropyLoss()详解

3 Pytorch交叉熵实现

3.1 举个栗子

  交叉熵损失,是分类任务中最常用的一个损失函数。在Pytorch中是基于下面的公式实现的。
Loss ⁡ ( x ^ , x ) = − ∑ i = 1 n x log ⁡ ( x ^ ) \operatorname{Loss}(\hat{x}, x)=-\sum_{i=1}^{n} x \log (\hat{x}) Loss(x^,x)=i=1nxlog(x^)
其中 x x x是真实标签, x ^ \hat{x} x^ 是预测的类分布(通常是使用softmax将模 型输出转换为概率分布)。
  取单个样本举例, 假设 x 1 = [ 0 , 1 , 0 ] x_1=[0, 1, 0] x1=[0,1,0], 模型预测样本 x 1 x_1 x1的概率为 x 1 ^ = [ 0.1 , 0.5 , 0.4 ] \hat{x_1}=[0.1, 0.5, 0.4] x1^=[0.1,0.5,0.4](因为是分布, 所以属于各个类的和为1)。则样本的损失计算如下所示:

   Loss ⁡ ( x 1 ^ , x 1 ) = − 0 × log ⁡ ( 0.1 ) − 1 × log ⁡ ( 0.5 ) − 0 × log ⁡ ( 0.4 ) = log ⁡ ( 0.5 ) \operatorname{Loss}(\hat{x_1}, x_1)=-0 \times \log (0.1)-1 \times \log (0.5)-0 \times \log (0.4)=\log (0.5) Loss(x1^,x1)=0×log(0.1)1×log(0.5)0×log(0.4)=log(0.5)

更详细的多分类交叉熵损失函数的例子可以参考文献【4】

3.2 Pytorch实现

实际使用中需要注意几点:

  1. torch.nn.CrossEntropyLoss(input, target)中的标签target使用的不是one-hot形式,而是类别的序号。形如 target = [1, 3, 2] 表示3个样本分别属于第1类、第3类、第2类。(单标签多分类问题)
  2. torch.nn.CrossEntropyLoss(input, target)的input是没有归一化的每个类的得分,而不是softmax之后的分布。

  输入的形式大概如下所示

import torch
target = [1, 3, 2]

input_ = [[0.13, -0.18, 0.87],
		 [0.25, -0.04, 0.32],
		 [0.24, -0.54, 0.53]]
# 然后就将他们扔到CrossEntropyLoss函数中,就可以得到损失。
loss_item = torch.nn.CrossEntropyLoss()
loss = loss_item(input, target)

CrossEntropyLoss函数里面的实现,如下所示:

def forward(self, input, target):
    return F.cross_entropy(input, target, weight=self.weight,
                           ignore_index=self.ignore_index, reduction=self.reduction)

是调用的torch.nn.functional(俗称F)中的cross_entropy()函数。

  此处需要区分一下:torch.nn.Module 和 torch.nn.functional(俗称F)中损失函数的区别。Module的损失函数例如CrossEntropyLoss、NLLLoss等是封装之后的损失函数类,是一个类,因此其中的变量可以自动维护。经常是对F中的函数的封装。而F中的损失函数只是单纯的函数。
下面看一下F.cross_entropy函数

3.3 F.cross_entropy

  • input:预测值,(batch,dim),这里dim就是要分类的总类别数
  • target:真实值,(batch),这里为啥是1维的?因为真实值并不是用one-hot形式表示,而是直接传类别id。
  • weight:指定权重,(dim),可选参数,可以给每个类指定一个权重。通常在训练数据中不同类别的样本数量差别较大时,可以使用权重来平衡。
  • ignore_index:指定忽略一个真实值,(int),也就是手动忽略一个真实值。
  • reduction:在[none, mean, sum]中选,string型。none表示不降维,返回和target相同形状;mean表示对一个batch的损失求均值;sum表示对一个batch的损失求和。

其中参数weight、ignore_index、reduction要在实例化CrossEntropyLoss对象时指定,例如:

loss = torch.nn.CrossEntropyLoss(reduction='none')

F中的cross_entropy的实现

return nll_loss(log_softmax(input, dim=1), target, weight, None, ignore_index, None, reduction)

可以看到就是先调用log_softmax,再调用nll_loss。log_softmax就是先softmax再取log。

4 参考文献

[1]Pytorch常用损失函数拆解
[2]Pytorch常用的交叉熵损失函数CrossEntropyLoss()详解
[3]负对数似然(negative log-likelihood)
[4]损失函数|交叉熵损失函数

  • 25
    点赞
  • 101
    收藏
    觉得还不错? 一键收藏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值