【PyTorch】torch.nn.functional.cross_entropy() 函数:分类任务的交叉熵损失函数

torch.nn.functional.cross_entropy

torch.nn.functional.cross_entropyPyTorch 中用于分类任务的交叉熵损失函数,用于衡量 预测概率分布与真实类别分布之间的差异,常用于 多分类任务(multi-class classification)


1. 交叉熵损失的数学公式

对于 单个样本,交叉熵损失的计算公式为:
L = − ∑ i = 1 C y i log ⁡ ( y i ^ ) \mathcal{L} = -\sum_{i=1}^{C} y_i \log (\hat{y_i}) L=i=1Cyilog(yi^)
其中:

  • C C C:类别总数。
  • y i y_i yi:真实类别的 one-hot 编码
  • y i ^ \hat{y_i} yi^:模型预测的概率分布(经过 Softmax)。

在 PyTorch 中,cross_entropy 直接接受未经过 Softmax 变换的 logits,并且内部会 自动计算 Softmax 并进行对数计算,提高数值稳定性。


2. torch.nn.functional.cross_entropy 的语法

torch.nn.functional.cross_entropy(input, target, weight=None, ignore_index=-100, reduction='mean')
参数说明
input模型的 logits 输出(未经过 softmax)
target真实类别索引(不是 one-hot 编码)
weight每个类别的损失权重(用于类别不平衡问题)
ignore_index忽略的类别索引(通常用于 padding
reductionmean(默认):取均值,sum:取总和,none:逐样本损失

3. 示例:计算交叉熵损失

import torch
import torch.nn.functional as F

# 假设有 3 个类别
logits = torch.tensor([[2.0, 0.5, 1.0], [0.5, 2.0, 1.5]])  # 模型输出(未经过 Softmax)
targets = torch.tensor([0, 2])  # 真实类别索引

# 计算交叉熵损失
loss = F.cross_entropy(logits, targets)
print(loss)

解析

  • logits 形状为 (batch_size, num_classes)
  • targets 形状为 (batch_size,),包含每个样本的真实类别索引。
  • 内部会自动计算 softmax 和 log,避免数值不稳定。

4. cross_entropy vs nll_loss

PyTorch 提供 nll_loss(负对数似然损失),但一般不直接使用,而是搭配 log_softmax

logits = torch.tensor([[2.0, 0.5, 1.0], [0.5, 2.0, 1.5]])
log_probs = torch.log_softmax(logits, dim=1)  # 先计算 log_softmax
targets = torch.tensor([0, 2])

# 使用 nll_loss(需要 log softmax 作为输入)
loss_nll = F.nll_loss(log_probs, targets)

# cross_entropy = softmax + log + nll_loss
loss_ce = F.cross_entropy(logits, targets)

print(loss_nll, loss_ce)  # 两者结果相同

结论

  • F.cross_entropy(logits, target) = softmax + log + nll_loss
  • cross_entropy 计算更稳定,建议优先使用

5. 处理类别不平衡

如果类别不均衡,可以使用 weight 参数:

class_weights = torch.tensor([1.0, 2.0, 3.0])  # 权重:类别 0 最小,类别 2 最大
loss = F.cross_entropy(logits, targets, weight=class_weights)
print(loss)

作用

  • 权重较大的类别损失更大,迫使模型关注少数类别

6. ignore_index 用于 padding

在 NLP 任务(如序列标注)中,可使用 ignore_index 忽略 padding

targets = torch.tensor([0, 2, -1])  # -1 代表 padding
loss = F.cross_entropy(logits, targets, ignore_index=-1)
print(loss)

作用

  • 不计算 padding 位置的损失,适用于 RNN、Transformer 等 NLP 任务

7. reduction 参数

reduction 控制损失计算方式:

loss_none = F.cross_entropy(logits, targets, reduction="none")  # 返回每个样本的损失
loss_sum = F.cross_entropy(logits, targets, reduction="sum")  # 总和
loss_mean = F.cross_entropy(logits, targets, reduction="mean")  # 默认均值

print(loss_none, loss_sum, loss_mean)

作用

  • none:返回逐样本损失,适用于 需要自定义损失计算 的任务。
  • sum:计算总损失。
  • mean:计算均值(默认)。

8. 适用场景

  • 图像分类(CNN)
  • 文本分类(Transformer)
  • 序列标注(ignore_index=-1 处理 padding
  • 处理类别不均衡数据(weight 选项)

9. 结论

  • torch.nn.functional.cross_entropy 是 PyTorch 中最常用的分类损失函数
  • 内部包含 softmax + log + nll_loss,无需手动计算 softmax
  • 适用于多分类问题input 为 logits,target 为类别索引)。
  • 可以使用 weight 处理类别不均衡,使用 ignore_index 处理 padding

在 PyTorch 分类任务中,推荐使用 F.cross_entropy 作为标准损失函数。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

彬彬侠

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值