criterion = nn.NLLLoss()
是 PyTorch 中的一个损失函数,全称为 Negative Log Likelihood Loss,即负对数似然损失函数。虽然它通常与分类问题相关联,但严格来说,它并不是直接计算交叉熵损失(Cross-Entropy Loss)的函数。然而,在分类问题的上下文中,特别是在使用 softmax 激活函数时,NLLLoss 与交叉熵损失在数学上是等价的,只是它们的实现方式略有不同。
NLLLoss
NLLLoss 计算的是实际标签所对应的预测概率的负对数值的平均值(或总和,取决于 reduction
参数)。但是,请注意,为了使用 NLLLoss,你需要将模型的输出(logits)传递给 softmax 函数(或直接在 PyTorch 中使用 log_softmax
)以获取对数概率,然后将这些对数概率和类别索引(而不是 one-hot 编码的标签)作为输入传递给 NLLLoss。
交叉熵损失函数
交叉熵损失函数是衡量两个概率分布差异的一种方法,在分类问题中,它通常用于衡量模型预测的概率分布与真实标签的概率分布之间的差异。在 PyTorch 中,交叉熵损失可以直接通过 nn.CrossEntropyLoss()
来实现,它结合了 softmax 激活函数和 NLLLoss 的功能,使得你不需要显式地将 logits 转换为概率分布。
PyTorch 中的表示
-
NLLLoss:
import torch
import torch.nn as nn
# 假设 logits 是模型的原始输出
logits = torch.randn(3, 5, requires_grad=True) # 3个样本,5个类别的logits
labels = torch.tensor([2, 4, 1]) # 真实标签的类别索引
# 注意:在实际应用中,你可能需要先对 logits 应用 softmax 或 log_softmax
# 但为了使用 NLLLoss,我们直接传递 logits 和 labels
# 然而,这里有一个误解,因为 NLLLoss 实际上期望的是 log_probabilities
# 正确的方式是使用 log_softmax,但 PyTorch 的 CrossEntropyLoss 已经为我们做了这些
# 如果要使用 NLLLoss,你应该这样(但通常不推荐这样直接做,因为它跳过了 softmax):
# log_probs = torch.nn.functional.log_softmax(logits, dim=1)
# loss = nn.NLLLoss()(log_probs, labels)
# 更简单且常见的方法是使用 CrossEntropyLoss
criterion_nll = nn.NLLLoss()
# 但请注意,这里我们不直接使用它,因为前面提到的原因
# 使用 CrossEntropyLoss(推荐)
criterion_ce = nn.CrossEntropyLoss()
loss_ce = criterion_ce(logits, labels)
总结一下,正确的做法是先对 logits 应用 log_softmax
(或直接在 PyTorch 中使用 F.log_softmax
,其中 F
是 torch.nn.functional
的常用别名),然后将结果和标签传递给 NLLLoss
实例的 __call__
方法来计算损失。这是处理多分类问题时使用 NLLLoss
的标准方式,尽管在 PyTorch 中更常见的是直接使用 CrossEntropyLoss
,因为它内部已经包含了 softmax 和 NLLLoss 的计算。
-
CrossEntropyLoss:
如上所示,nn.CrossEntropyLoss()
直接接受 logits 和类别索引作为输入,并在内部计算 softmax 和 NLLLoss。这是处理多分类问题时更常见和推荐的方法。