Pytorch CrossEntropyLoss() 原理和用法详解


1. 前言

在 PyTorch 中,CrossEntropyLoss() 是一个用于计算交叉熵损失(Cross-Entropy Loss)的损失函数。它通常用于多类别分类任务中,特别是当类别之间不平衡或者样本数目不均衡时。

官方文档用法参考这里

2. 交叉熵的含义

交叉熵(Cross-Entropy)是一种用于比较两个概率分布之间差异的度量。在机器学习中,交叉熵通常用作损失函数,用于衡量模型预测与真实标签之间的差异,尤其在分类任务中广泛使用。

假设有两个概率分布 P P P Q Q Q,其中 P P P 表示真实分布, Q Q Q 表示模型的预测分布。这两个分布都是离散的,通常用于表示类别的概率分布。交叉熵损失函数的计算方式如下:

H ( P , Q ) = − ∑ i P ( i ) log ⁡ ( Q ( i ) ) H(P, Q) = - \sum_{i} P(i) \log(Q(i)) H(P,Q)=iP(i)log(Q(i))

其中, i i i 表示类别的索引, P ( i ) P(i) P(i) Q ( i ) Q(i) Q(i) 分别表示真实分布和预测分布中第 i i i 个类别的概率。交叉熵衡量了在真实分布下观察到的事件的平均信息量,与预测分布 Q Q Q 相对应。

在分类任务中,通常使用交叉熵损失函数来衡量模型预测的概率分布与真实标签的差异。在训练过程中,模型的目标是最小化交叉熵损失,使得模型的预测分布尽可能接近真实分布。

交叉熵越小,越接近真实模型。当模型的预测与真实标签完全一致时,交叉熵达到最小值为 0。

3. 举例计算交叉熵

假设我们有一个分类任务,共有 3 个类别,并且模型的预测结果和真实标签如下:

  • 5个样本所属的真实标签(Ground Truth):[1, 0, 2, 1, 2]
  • 模型的预测概率分布:
    • 类别 0 的预测概率分布:[0.2, 0.6, 0.2]
    • 类别 1 的预测概率分布:[0.7, 0.2, 0.1]
    • 类别 2 的预测概率分布:[0.1, 0.1, 0.8]

首先,我们需要计算每个样本的交叉熵损失,然后将它们求和并除以样本数量来计算平均损失。计算过程如下:

(1)对于第一个样本(真实标签为 1):

  • 真实标签概率分布:[0, 1, 0]
  • 模型预测概率分布:[0.7, 0.2, 0.1]
  • 交叉熵损失:-1 * (1 * log(0.7) + 0 * log(0.2) + 0 * log(0.1)) ≈ 0.36

(2)对于第二个样本(真实标签为 0):

  • 真实标签概率分布:[1, 0, 0]
  • 模型预测概率分布:[0.2, 0.6, 0.2]
  • 交叉熵损失:-1 * (0 * log(0.2) + 1 * log(0.6) + 0 * log(0.2)) ≈ 0.51

(3)对于第三个样本(真实标签为 2):

  • 真实标签概率分布:[0, 0, 1]
  • 模型预测概率分布:[0.1, 0.1, 0.8]
  • 交叉熵损失:-1 * (0 * log(0.1) + 0 * log(0.1) + 1 * log(0.8)) ≈ 0.22

(4)对于第四个样本(真实标签为 1):

  • 真实标签概率分布:[0, 1, 0]
  • 模型预测概率分布:[0.7, 0.2, 0.1]
  • 交叉熵损失:-1 * (1 * log(0.7) + 0 * log(0.2) + 0 * log(0.1)) ≈ 0.36

(5)对于第五个样本(真实标签为 2):

  • 真实标签概率分布:[0, 0, 1]
  • 模型预测概率分布:[0.1, 0.1, 0.8]
  • 交叉熵损失:-1 * (0 * log(0.1) + 0 * log(0.1) + 1 * log(0.8)) ≈ 0.22

最后,将每个样本的交叉熵损失相加,并除以样本数量得到平均损失:
平均损失 = 0.36 + 0.51 + 0.22 + 0.36 + 0.22 5 ≈ 0.334 \text{平均损失} = \frac{0.36 + 0.51 + 0.22 + 0.36 + 0.22}{5} \approx 0.334 平均损失=50.36+0.51+0.22+0.36+0.220.334

所以,该多分类任务的平均交叉熵损失约为 0.334。

4. CrossEntropyLoss 计算过程

4.1 具体过程

Pytorch 中 CrossEntropyLoss() 函数包含以下步骤:

  1. softmax
  2. log
  3. NLLLoss

以下是验证流程:

import torch
import torch.nn as nn

_input = torch.randn(4, 3)
print('input:\n', _input)

target = torch.tensor([1, 2, 0, 1])  # 设置输出具体值 

################# 输出:#################
input:
 tensor([[-0.0251, -1.0660, -1.2555],
        [ 0.4511,  1.4464,  0.9722],
        [ 0.3108,  0.4180, -0.4181],
        [ 1.0811, -1.6097, -0.6413]])
# 计算输入softmax
softmax_f = nn.Softmax(dim=1)
soft_output = softmax_f(_input)
print('softmax_output:\n', soft_output)

# 在softmax的基础上取log
log_output = torch.log(soft_output)
print('log_output:\n', log_output)

################# 输出:#################
softmax_output:
 tensor([[0.6078, 0.2146, 0.1776],
        [0.1855, 0.5020, 0.3124],
        [0.3853, 0.4289, 0.1859],
        [0.8023, 0.0544, 0.1433]])
log_output:
 tensor([[-0.4979, -1.5388, -1.7284],
        [-1.6845, -0.6891, -1.1633],
        [-0.9538, -0.8466, -1.6828],
        [-0.2203, -2.9111, -1.9427]])
# softmax+log与nn.LogSoftmaxloss的结果是一致的。
logsoftmax_func = nn.LogSoftmax(dim=1)
logsoftmax_output = logsoftmax_func(_input)
print('logsoftmax_output:\n', logsoftmax_output)

################# 输出:#################
logsoftmax_output:
 tensor([[-0.4979, -1.5388, -1.7284],
        [-1.6845, -0.6891, -1.1633],
        [-0.9538, -0.8466, -1.6828],
        [-0.2203, -2.9111, -1.9427]])
# 先用nn.NLLLoss()计算
nllloss_func = nn.NLLLoss()
nlloss_output = nllloss_func(logsoftmax_output, target)
print('nlloss_output:\n', nlloss_output)

# 和nn.CrossEntropyLoss()的结果是一样的
crossentropyloss = nn.CrossEntropyLoss()
crossentropyloss_output = crossentropyloss(_input, target)
print('crossentropyloss_output:\n', crossentropyloss_output)

################# 输出:#################
nlloss_output:
 tensor(1.6417)
crossentropyloss_output:
 tensor(1.6417)

上述过程验证了CrossEntropyLoss() 函数包含以下步骤:

  1. softmax
  2. log
  3. NLLLoss

4.2 nn.NLLLoss() 的用法

下面有必要介绍 nn.NLLLoss() 的用法。在 PyTorch 中,NLLLoss() 是一个用于计算负对数似然损失(Negative Log Likelihood Loss)的损失函数。官方文档参考这里

例子:

from torch import nn
import torch

# 初始化
nllloss = nn.NLLLoss() # 可选参数中有 reduction='mean', 'sum', 默认mean

# 两个张量,一个是预测向量,一个是真实标签label
predict = torch.Tensor([[2, 3, 1],
                        [3, 7, 9]])
label = torch.tensor([1, 2])
v = nllloss(predict, label)
print(v)

################# 输出:#################
tensor(-6.)

解释:
上面的label 表示依次在 predict 选取值。例如,上面的 label = [1,2],那么在predict[0]中选取3,在predict[1]中选取9。然后求平均值并取负: − ( 3 + 9 ) / 2 = − 6 -(3+9)/2=-6 (3+9)/2=6

5. 参考

nn.CrossEntropyLoss
nn.NLLLoss

欢迎关注本人,我是喜欢搞事的程序猿;一起进步,一起学习;

欢迎关注知乎/CSDN:SmallerFL

也欢迎关注我的wx公众号(精选高质量文章):一个比特定乾坤

  • 12
    点赞
  • 29
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

SmallerFL

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值