【损失函数】(三) NLLLoss原理 & pytorch代码解析

1.简介

接下来应该要介绍分类任务里最常用的CrossEntropyLoss,但是CrossEntropyLoss里面用到了NLLLoss,所以我们先了解NLLLoss是什么。


2.NLLloss

NLLLoss全名叫做Negative Log Likelihood Loss,顾名思义,输入就是log likelihood,对输入的对应部分取负号就是这个loss的输出了,公式可以表示为:

loss(x,y)=L=\left \{ l_{1},...,l_{N} \right \}^{T},l_{n}=-w_{y_{n}}x_{n,y_{n}}

其中w_{y_{n}}是每个类别的权重,默认的全为1,x_{n,y_{n}}表示对应target那一类的概率。更简单一点来说,就是将对应类别的x输出取负值。比如说x=[-2, -3, -0.5],y=[0, 0, 1],那么NLL loss就为对应target那类的输出值的相反数,所以

loss(x,y)=-(-0.5)=0.5


3.思考

那么看到这个公式第一个疑问:loss会出现负数的情况吗?这是不会发生的。因为NLL设定就是用于log likelihood的,所以用来计算loss的输出x是经过softmax再加log后得到的结果。softmax后范围在0-1之间,log之后范围就在-inf-0之间,所以loss肯定为非负数。

接下来是第二个疑问:这个损失函数为什么要这么设计呢?个人认为它是将多分类问题简单化处理,通过softmax将输出加权后使其加起来为1,这样的话只计算一个类别的loss的同时也考虑到了其他概率的影响,不需要像mse那样每个类别都进行计算。实际上softmax+log+NLLloss也就是另外一个常用的损失函数CrossEntropyLoss。关于CrossEntropyLoss具体的介绍会在下一篇文章里说明。

最后一个疑问:可不可以不要log呢?因为实际上softmax之后从数学层面上就已经满足了loss的计算逻辑,加上log只是为了更加符合交叉熵公式的计算方法。所以要去掉log肯定也是可以的,只是交叉熵更加通用一点。还有一种说法是加了log后会让每个点的梯度不一样,log函数越接近0的点梯度越大,损失越大,更新的速度也越快,这样可以加快模型的收敛。


4.pytorch代码

以下代码为pytorch官方NLLloss代码,可以看到里面有几个参数,我们大多数情况下使用默认参数设置就好。

torch.nn.NLLLoss(weight=None, size_average=None, ignore_index=- 100, reduce=None, reduction='mean')

其中:

  • weight表示每个类别的权重,当标签不平衡的时候可以使用来防止过拟合。
  • size_average表示是否将样本的loss进行平均之后输出,默认为true。
  • ignore_index表示忽略某一类别,不想训练某些类别时可用。
  • reduce表示是否将输出进行压缩,默认为true。当它为false的时候就会无视size_average。
  • reduction表示用怎么的方法进行reduce。可以设置为'none','mean','sum'。

其中target为对应类别的真值索引:

import torch
import torch.nn as nn

a = torch.randn(3, 5)
b = torch.Tensor([0, 4, 1]).long()


criterion = nn.NLLLoss()
c = criterion(torch.log(torch.softmax(a, dim=1)), b)
print(c)


业务合作/学习交流+v:lizhiTechnology

 如果想要了解更多损失函数相关知识,可以参考我的专栏和其他相关文章:

损失函数_Lcm_Tech的博客-CSDN博客

【损失函数】(一) L1Loss原理 & pytorch代码解析_l1 loss-CSDN博客

【损失函数】(二) L2Loss原理 & pytorch代码解析_l2 loss-CSDN博客

【损失函数】(三) NLLLoss原理 & pytorch代码解析_nll_loss-CSDN博客

【损失函数】(四) CrossEntropyLoss原理 & pytorch代码解析_crossentropyloss 权重-CSDN博客

【损失函数】(五) BCELoss原理 & pytorch代码解析_bce损失函数源码解析-CSDN博客

如果想要了解更多深度学习相关知识,可以参考我的其他文章:

深度学习_Lcm_Tech的博客-CSDN博客

【优化器】(一) SGD原理 & pytorch代码解析_sgd优化器-CSDN博客

【图像生成】(一) DNN 原理 & pytorch代码实例_pytorch dnn代码-CSDN博客

  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值