深入理解交叉熵损失CrossEntropyLoss - nn.NLLLoss(Negative Log-Likelihood Loss)

深入理解交叉熵损失 CrossEntropyLoss - nn.NLLLoss(Negative Log-Likelihood Loss)

flyfish

nn.NLLLoss

nn.NLLLoss是PyTorch中的一种损失函数,全称为负对数似然损失(Negative Log-Likelihood Loss)。这个损失函数通常与最后一层为 nn.LogSoftmax 的网络层一起使用,主要用于多分类问题。其公式为:

NLLLoss = − 1 N ∑ i = 1 N log ⁡ ( p i , y i ) \text{NLLLoss} = - \frac{1}{N} \sum_{i=1}^{N} \log(p_{i, y_i}) NLLLoss=N1i=1Nlog(pi,yi)

其中, N N N 是批次的样本数量, p i , y i p_{i, y_i} pi,yi 是第 i i i 个样本的真实类别 y i y_i yi 对应的概率。

代码实现

用 PyTorch 实现上述计算过程:

import torch
import torch.nn as nn

# 示例输入(未归一化的分数)(批次大小为3,类别数为3)
logits = torch.tensor([[2.0, 1.0, 0.1], 
                       [0.5, 2.5, 0.3], 
                       [1.2, 0.7, 1.5]])

# 真实标签
labels = torch.tensor([0, 1, 2])

# LogSoftmax 层
log_softmax = nn.LogSoftmax(dim=1)
log_probs = log_softmax(logits)

# NLLLoss 损失函数
nll_loss = nn.NLLLoss()
loss = nll_loss(log_probs, labels)

print(loss.item())

在这个示例中,log_softmax 将未归一化的分数转换为对数概率,nll_loss 计算负对数似然损失。最终输出的 loss 是一个标量,表示当前批次的平均损失。

在这个示例中:

  1. input 也就是logits ,是原始的网络输出(未归一化的得分)。
  2. log_softmax 将这些得分转换为对数概率。
  3. target 是真实的标签(0, 1 或 2)。
  4. nll_loss 计算负对数似然损失。

通过 LogSoftmax 层将 input 转换为对数概率,然后通过 NLLLoss 计算损失。这样,模型的训练目标就是最小化这个负对数似然损失。

注意事项

  • 输入格式:输入应该是对数概率(log probabilities),这通常通过 nn.LogSoftmax 层获得。如果输入是标准概率(通过 nn.Softmax 获得),则不能直接使用 nn.NLLLoss。
  • 标签格式:标签应该是类别的索引(整数形式),而不是 one-hot 编码。
    通过正确使用 nn.NLLLoss,可以有效地训练多分类问题中的神经网络。

在 PyTorch 中,如果使用 nn.CrossEntropyLoss 或 nn.NLLLoss,则需要将标签表示为类别索引(整数形式)。这些损失函数会在内部处理模型输出的 softmax 和 log softmax 操作。因此,您不需要手动进行这些变换。

负对数似然

负对数似然(Negative Log-Likelihood,简称NLL)是统计学和机器学习中常用的一种损失函数,特别适用于分类问题。它的主要作用是衡量模型预测的概率分布与真实分布之间的差异。数学上,负对数似然损失函数与最大似然估计(Maximum Likelihood Estimation, MLE)紧密相关。

数学定义

对于一个给定的样本 x i x_i xi 及其对应的真实标签 y i y_i yi,假设模型预测的概率分布为 p ( y i ∣ x i ; θ ) p(y_i|x_i; \theta) p(yixi;θ),其中 θ \theta θ 是模型参数。负对数似然损失函数可以定义为:

NLL = − ∑ i = 1 N log ⁡ ( p ( y i ∣ x i ; θ ) ) \text{NLL} = -\sum_{i=1}^{N} \log(p(y_i|x_i; \theta)) NLL=i=1Nlog(p(yixi;θ))

其中, N N N 是样本数量, p ( y i ∣ x i ; θ ) p(y_i|x_i; \theta) p(yixi;θ) 是模型在输入 x i x_i xi 上对真实标签 y i y_i yi 的预测概率。

意义

负对数似然损失函数的主要意义在于它提供了一种度量模型预测概率与真实标签之间匹配程度的方式。具体来说:

  • 匹配程度:如果模型在真实标签上的预测概率越高,负对数似然值越低,表示模型预测得越好。
  • 惩罚错误:如果模型在真实标签上的预测概率很低,负对数似然值会非常大,表示模型预测得很差。

最大似然估计

负对数似然损失函数与最大似然估计有着密切的关系。最大似然估计的目标是找到一组参数 θ \theta θ,使得在给定数据上的似然函数(即样本在当前模型参数下出现的概率)最大化:

θ MLE = arg ⁡ max ⁡ θ ∏ i = 1 N p ( y i ∣ x i ; θ ) \theta_{\text{MLE}} = \arg\max_{\theta} \prod_{i=1}^{N} p(y_i|x_i; \theta) θMLE=argθmaxi=1Np(yixi;θ)

由于对数函数是单调递增的,最大化似然函数等价于最大化对数似然函数:

θ MLE = arg ⁡ max ⁡ θ ∑ i = 1 N log ⁡ ( p ( y i ∣ x i ; θ ) ) \theta_{\text{MLE}} = \arg\max_{\theta} \sum_{i=1}^{N} \log(p(y_i|x_i; \theta)) θMLE=argθmaxi=1Nlog(p(yixi;θ))

进一步,最大化对数似然函数又等价于最小化负对数似然函数:

θ MLE = arg ⁡ min ⁡ θ − ∑ i = 1 N log ⁡ ( p ( y i ∣ x i ; θ ) ) \theta_{\text{MLE}} = \arg\min_{\theta} -\sum_{i=1}^{N} \log(p(y_i|x_i; \theta)) θMLE=argθmini=1Nlog(p(yixi;θ))

负对数似然(NLL)损失函数

基本公式

负对数似然损失函数的基本公式为:

NLL = − ∑ i = 1 N log ⁡ ( p ( y i ∣ x i ; θ ) ) \text{NLL} = -\sum_{i=1}^{N} \log(p(y_i|x_i; \theta)) NLL=i=1Nlog(p(yixi;θ))

其中:

  • N N N 是样本的数量。
  • y i y_i yi 是第 i i i 个样本的真实标签。
  • x i x_i xi 是第 i i i 个样本的输入。
  • p ( y i ∣ x i ; θ ) p(y_i|x_i; \theta) p(yixi;θ) 是模型在给定输入 x i x_i xi 和参数 θ \theta θ 下,对真实标签 y i y_i yi 的预测概率。

逐项解释

  1. 预测概率 p ( y i ∣ x i ; θ ) p(y_i|x_i; \theta) p(yixi;θ):
    这个值是模型在输入 x i x_i xi 下预测标签为 y i y_i yi 的概率。对于分类问题,通常使用 softmax 函数来计算每个类别的概率: p ( y i ∣ x i ; θ ) = exp ⁡ ( z y i ) ∑ j = 1 C exp ⁡ ( z j ) p(y_i|x_i; \theta) = \frac{\exp(z_{y_i})}{\sum_{j=1}^{C} \exp(z_j)} p(yixi;θ)=j=1Cexp(zj)exp(zyi)其中, z j z_j zj 是模型在输入 x i x_i xi 下对于类别 j j j 的原始输出(logits), C C C 是类别数。
  2. 对数概率 log ⁡ ( p ( y i ∣ x i ; θ ) ) \log(p(y_i|x_i; \theta)) log(p(yixi;θ)):
    取对数是为了将概率空间转化为对数空间,便于处理数值稳定性问题,并将乘法关系转换为加法关系,这样计算更加稳定和高效。
  3. 求和 ∑ i = 1 N \sum_{i=1}^{N} i=1N:
    对所有样本的对数概率求和,得到整个批次的总对数似然。
  4. 负号 − - :
    加上负号是因为我们通常希望最大化似然(即找到使预测概率最大的参数),但在优化中,我们通常最小化损失函数。因此,通过取负号,最大化似然问题转换为最小化负对数似然问题。

具体例子

假设我们有一个分类问题,有三个类别,模型输出的是未归一化的分数(logits),例如:

logits = [ 2.0 1.0 0.1 0.5 2.5 0.3 1.2 0.7 1.5 ] \text{logits} = \begin{bmatrix} 2.0 & 1.0 & 0.1 \\ 0.5 & 2.5 & 0.3 \\ 1.2 & 0.7 & 1.5 \end{bmatrix} logits= 2.00.51.21.02.50.70.10.31.5

真实标签为:

labels = [ 0 1 2 ] \text{labels} = \begin{bmatrix} 0 \\ 1 \\ 2 \end{bmatrix} labels= 012

首先,计算 softmax 概率:

probs = softmax ( logits ) \text{probs} = \text{softmax}(\text{logits}) probs=softmax(logits)

softmax 的计算如下:

softmax ( z i ) = exp ⁡ ( z i ) ∑ j = 1 C exp ⁡ ( z j ) \text{softmax}(z_i) = \frac{\exp(z_i)}{\sum_{j=1}^{C} \exp(z_j)} softmax(zi)=j=1Cexp(zj)exp(zi)

对于第一个样本,softmax 概率为:

softmax ( [ 2.0 , 1.0 , 0.1 ] ) = [ exp ⁡ ( 2.0 ) exp ⁡ ( 2.0 ) + exp ⁡ ( 1.0 ) + exp ⁡ ( 0.1 ) , exp ⁡ ( 1.0 ) exp ⁡ ( 2.0 ) + exp ⁡ ( 1.0 ) + exp ⁡ ( 0.1 ) , exp ⁡ ( 0.1 ) exp ⁡ ( 2.0 ) + exp ⁡ ( 1.0 ) + exp ⁡ ( 0.1 ) ] \text{softmax}(\begin{bmatrix} 2.0, 1.0, 0.1 \end{bmatrix}) = \begin{bmatrix} \frac{\exp(2.0)}{\exp(2.0) + \exp(1.0) + \exp(0.1)}, \frac{\exp(1.0)}{\exp(2.0) + \exp(1.0) + \exp(0.1)}, \frac{\exp(0.1)}{\exp(2.0) + \exp(1.0) + \exp(0.1)} \end{bmatrix} softmax([2.0,1.0,0.1])=[exp(2.0)+exp(1.0)+exp(0.1)exp(2.0),exp(2.0)+exp(1.0)+exp(0.1)exp(1.0),exp(2.0)+exp(1.0)+exp(0.1)exp(0.1)]

我们可以计算所有样本的 softmax 概率,然后取对数概率:

log ⁡ ( probs ) = log ⁡ ( softmax ( logits ) ) \log(\text{probs}) = \log(\text{softmax}(\text{logits})) log(probs)=log(softmax(logits))

负对数似然损失为:

NLL = − ∑ i = 1 N log ⁡ ( p ( y i ∣ x i ; θ ) ) \text{NLL} = -\sum_{i=1}^{N} \log(p(y_i|x_i; \theta)) NLL=i=1Nlog(p(yixi;θ))

  • 18
    点赞
  • 15
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

西笑生

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值