torch.nn中NLLLoss与CrossEntropyLoss比较详解

本篇文章中我们将详细比较torch.nn中两个损失函数类NLLLoss与CrossEntropyLoss,首先我们将介绍负对数似然和交叉熵,其次我们再介绍在Pytorch中两个类具体的执行计算方式。

数学推导

我们来考虑一个 n n n分类问题, 为了使讨论更为简洁,我们这里只考虑一个样本(sample),输入为 x \boldsymbol{x} x ,经模型输出为 l o g i t s logits logits,经过Softmax归一化后预测概率分布为 y ^ = S o f t m a x ( l o g i t s ) = [ p 1 , p 2 , … , p n ] T \hat{y}=Softmax(logits)=[p_1,p_2,\dots,p_n]^T y^=Softmax(logits)=[p1,p2,,pn]T,真实标签为 y \boldsymbol{y} y,假设该样本实际上属于第 c c c类,即 y = [ y 1 , y 2 , … , y n ] T = [ 0 , 0 , … , 1 , … , 0 ] T \boldsymbol{y}=[y_1,y_2,\dots,y_n]^T=[0,0,\dots,1,\dots,0]^T y=[y1,y2,,yn]T=[0,0,,1,,0]T为one-hot向量。
我们想要最大化样本属于真实类别 c c c的概率, 即最小化负对数似然(negetive log likelihood)
N L L = − L o g P ( y ^ ∣ x ) = − l o g p c (1) \begin{aligned} NLL &= -LogP(\hat{y}|x)=-logp_c \tag{1} \end{aligned} NLL=LogP(y^x)=logpc(1)
另外要注意深度学习中 l o g log log函数往往指的是 l n ln ln函数,即自然对数。
而我们知道 y \boldsymbol{y} y为one-hot向量,只有第 c c c维位置为1,故
N L L = − l o g p c = − 1 ⋅ l o g p c = − ( 0 ⋅ l o g p 1 + 0 ⋅ l o g p 2 + ⋯ + 1 ⋅ l o g p c + ⋯ + 0 ⋅ l o g p n ) = − ∑ i = 1 n y i l o g p i = − y ⋅ l o g y ^ (2) \begin{aligned} NLL &= -logp_c \\ &= -1\cdot logp_c \\ &=-(0\cdot logp_1+0\cdot logp_2+\dots+1\cdot logp_c+\dots+0\cdot logp_n) \\ &= -\sum\limits_{i=1}^ny_ilogp_i \\ &= -\boldsymbol{y}\cdot log\hat{\textbf{y}} \tag{2} \end{aligned} NLL=logpc=1logpc=(0logp1+0logp2++1logpc++0logpn)=i=1nyilogpi=ylogy^(2)
最后结果即为交叉熵(Cross Entropy)
C E = − y ⋅ l o g y ^ (3) CE = -\boldsymbol{y}\cdot log\hat{\textbf{y}} \tag{3} CE=ylogy^(3)
所以对于n分类问题,两者是等价的。

代码实践

但事实上在Pytorch中,具体的执行计算方式有所不同。
由公式(2)我们可得到
C E = − y ⋅ L o g S o f t m a x ( l o g i t s ) (4) \begin{aligned} CE &= -\boldsymbol{y}\cdot LogSoftmax(logits) \tag{4} \end{aligned} CE=yLogSoftmax(logits)(4)
而CrossEntropyLoss()事实上是对logits进行LogSoftmax计算交叉熵,但是NLLLoss()并没有这一步,需要对模型输出的logits外加LogSoftmax操作。
下面我们通过代码演示来展示在Pytorch框架中两种损失函数的实际应用区别。

import torch.nn as nn
import torch.nn.functional as F
nnl = nn.NLLLoss()
ce = nn.CrossEntropyLoss()
ls = nn.LogSoftmax(dim=-1)
logits = torch.rand(3)
target = torch.tensor(1)
print(logits)
loss1 = nnl(ls(logits), target)
loss2 = ce(logits, target)
print(loss1)
print(loss2)
# output
#tensor([0.0437, 0.1241, 0.2193])
#tensor(1.1061)
#tensor(1.1061)

所以我们最终可以总结为
n n . L o g S o f t m a x ( ) & n n . N L L L o s s ( ) ⇔ n n . C r o s s E n t r o p y L o s s ( ) \textcolor{red} {nn.LogSoftmax() \& nn.NLLLoss() \quad \Leftrightarrow \quad nn.CrossEntropyLoss()} nn.LogSoftmax()&nn.NLLLoss()nn.CrossEntropyLoss()

  • 19
    点赞
  • 18
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
nn.NLLLoss是一个PyTorch的损失函数,用于计算负对数似然损失。在神经网络,它通常与nn.LogSoftmax结合使用,用于多分类问题。NLLLoss的计算步骤如下: 1. 输入的形状为[batch_size, num_classes]的张量,其每一行代表一个样本的预测概率分布。 2. 需要提供一个形状为[batch_size]的目标标签张量。 3. 首先,通过使用nn.LogSoftmax函数对输入进行log softmax操作,得到每个类别的对数概率。 4. 接下来,根据目标标签从对数概率选择相应的概率。 5. 最后,将选择的概率取负并求和,得到最终的损失值。 在上述代码示例,我们可以看到使用nn.NLLLoss计算了输入和目标之间的损失。输出的值为tensor(2.1280),grad_fn=<NllLossBackward0>,其grad_fn表示反向传播函数。 <span class="em">1</span><span class="em">2</span><span class="em">3</span> #### 引用[.reference_title] - *1* *3* [深度学习方法——NLLloss简单概括](https://blog.csdn.net/qq_50571974/article/details/124314082)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v93^chatsearchT3_1"}}] [.reference_item style="max-width: 50%"] - *2* [Pytorch损失函数nn.NLLLoss2d()用法说明](https://download.csdn.net/download/weixin_38536397/14841223)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v93^chatsearchT3_1"}}] [.reference_item style="max-width: 50%"] [ .reference_list ]

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值