torch.nn.functional.cross_entropy
torch.nn.functional.cross_entropy
是 PyTorch 中用于分类任务的交叉熵损失函数,用于衡量 预测概率分布与真实类别分布之间的差异,常用于 多分类任务(multi-class classification)。
1. 交叉熵损失的数学公式
对于 单个样本,交叉熵损失的计算公式为:
L
=
−
∑
i
=
1
C
y
i
log
(
y
i
^
)
\mathcal{L} = -\sum_{i=1}^{C} y_i \log (\hat{y_i})
L=−i=1∑Cyilog(yi^)
其中:
- C C C:类别总数。
- y i y_i yi:真实类别的 one-hot 编码。
- y i ^ \hat{y_i} yi^:模型预测的概率分布(经过 Softmax)。
在 PyTorch 中,cross_entropy
直接接受未经过 Softmax 变换的 logits,并且内部会 自动计算 Softmax 并进行对数计算,提高数值稳定性。
2. torch.nn.functional.cross_entropy
的语法
torch.nn.functional.cross_entropy(input, target, weight=None, ignore_index=-100, reduction='mean')
参数 | 说明 |
---|---|
input | 模型的 logits 输出(未经过 softmax) |
target | 真实类别索引(不是 one-hot 编码) |
weight | 每个类别的损失权重(用于类别不平衡问题) |
ignore_index | 忽略的类别索引(通常用于 padding ) |
reduction | mean (默认):取均值,sum :取总和,none :逐样本损失 |
3. 示例:计算交叉熵损失
import torch
import torch.nn.functional as F
# 假设有 3 个类别
logits = torch.tensor([[2.0, 0.5, 1.0], [0.5, 2.0, 1.5]]) # 模型输出(未经过 Softmax)
targets = torch.tensor([0, 2]) # 真实类别索引
# 计算交叉熵损失
loss = F.cross_entropy(logits, targets)
print(loss)
解析
logits
形状为(batch_size, num_classes)
。targets
形状为(batch_size,)
,包含每个样本的真实类别索引。- 内部会自动计算 softmax 和 log,避免数值不稳定。
4. cross_entropy
vs nll_loss
PyTorch 提供 nll_loss
(负对数似然损失),但一般不直接使用,而是搭配 log_softmax
:
logits = torch.tensor([[2.0, 0.5, 1.0], [0.5, 2.0, 1.5]])
log_probs = torch.log_softmax(logits, dim=1) # 先计算 log_softmax
targets = torch.tensor([0, 2])
# 使用 nll_loss(需要 log softmax 作为输入)
loss_nll = F.nll_loss(log_probs, targets)
# cross_entropy = softmax + log + nll_loss
loss_ce = F.cross_entropy(logits, targets)
print(loss_nll, loss_ce) # 两者结果相同
结论
F.cross_entropy(logits, target)
=softmax + log + nll_loss
cross_entropy
计算更稳定,建议优先使用。
5. 处理类别不平衡
如果类别不均衡,可以使用 weight
参数:
class_weights = torch.tensor([1.0, 2.0, 3.0]) # 权重:类别 0 最小,类别 2 最大
loss = F.cross_entropy(logits, targets, weight=class_weights)
print(loss)
作用
- 权重较大的类别损失更大,迫使模型关注少数类别。
6. ignore_index
用于 padding
在 NLP 任务(如序列标注)中,可使用 ignore_index
忽略 padding
:
targets = torch.tensor([0, 2, -1]) # -1 代表 padding
loss = F.cross_entropy(logits, targets, ignore_index=-1)
print(loss)
作用
- 不计算
padding
位置的损失,适用于 RNN、Transformer 等 NLP 任务。
7. reduction
参数
reduction
控制损失计算方式:
loss_none = F.cross_entropy(logits, targets, reduction="none") # 返回每个样本的损失
loss_sum = F.cross_entropy(logits, targets, reduction="sum") # 总和
loss_mean = F.cross_entropy(logits, targets, reduction="mean") # 默认均值
print(loss_none, loss_sum, loss_mean)
作用
none
:返回逐样本损失,适用于 需要自定义损失计算 的任务。sum
:计算总损失。mean
:计算均值(默认)。
8. 适用场景
- 图像分类(CNN)
- 文本分类(Transformer)
- 序列标注(ignore_index=-1 处理
padding
) - 处理类别不均衡数据(
weight
选项)
9. 结论
torch.nn.functional.cross_entropy
是 PyTorch 中最常用的分类损失函数。- 内部包含 softmax + log + nll_loss,无需手动计算
softmax
。 - 适用于多分类问题(
input
为 logits,target
为类别索引)。 - 可以使用
weight
处理类别不均衡,使用ignore_index
处理padding
。
在 PyTorch 分类任务中,推荐使用 F.cross_entropy
作为标准损失函数。