参考链接:详解torch.nn.NLLLOSS - 知乎分类问题的损失函数中,经常会遇到torch.nn.NLLLOSS。torch.nn.NLLLOSS通常不被独立当作损失函数,而需要和softmax、log等运算组合当作损失函数。 torch.nn.NLLLOSS官方链接: NLLLoss - PyTorch 1.9.0 documentat…
https://zhuanlan.zhihu.com/p/383044774
一、torch.nn.NLLLOSS运算规则
from torch import nn
import torch
# nllloss首先需要初始化
nllloss = nn.NLLLoss() # 可选参数中有 reduction='mean', 'sum', 默认mean
在使用nllloss时,需要有两个张量,一个是预测向量,一个是label
predict = torch.Tensor([[2, 3, 1]]) # shape: (n, category)
label = torch.tensor([1]) # shape: (n,)
- 这里解释一下predict和label,label的shape是n,表示了n个向量对应的正确类别,比如这里label为1,则表明向量(2,3,1)对应的类别是1;
- predict则表示每个类别预测的概率,比如向量(2,3,1)则表示类别0,1,2预测的概率分别为(2,3,1)(先忽略概率大于1的问题)
- predict shape为(1, category)的情况
#
predict = torch.Tensor([[2, 3, 1]])
label = torch.tensor([1])
nllloss(predict, label)
# output: tensor(-3.)
nllloss对两个向量的操作为,将predict中的向量,在label中对应的index取出,并取负号输出。label中为1,则取2,3,1中的第1位3,取负号后输出。
2. predict shape为(n, category)的情况
predict = torch.Tensor([[2, 3, 1],
[3, 7, 9]])
label = torch.tensor([1, 2])
nllloss(predict, label)
# output: tensor(-6)
nllloss对两个向量的操作为,继续将predict中的向量,在label中对应的index取出,并取负号输出。label中为1,则取2,3,1中的第1位3,label第二位为2,则取出3,7,9的第2位9,将两数取平均后加负号后输出。
这时就可以看到最开始的nllloss初始化的时候,如果参数reduction取'mean',就是上述结果。如果reduction取'sum',那么各行取出对应的结果,就是取sum后输出,如下所示:
nllloss = nn.NLLLoss( reduction='sum')
predict = torch.Tensor([[2, 3, 1],
[3, 7, 9]])
label = torch.tensor([1, 2])
nllloss(predict, label)
# output: tensor(-12)
二、与torch.nn.CrossEntropyLoss的区别
torch.nn.CrossEntropyLoss相当于softmax + log + nllloss。
上面的例子中,预测的概率大于1明显不符合预期,可以使用softmax归一,取log后是交叉熵,取负号是为了符合loss越小,预测概率越大。
所以使用nll loss时,可以这样操作
nllloss = nn.NLLLoss()
predict = torch.Tensor([[2, 3, 1],
[3, 7, 9]])
predict = torch.log(torch.softmax(predict, dim=-1))
label = torch.tensor([1, 2])
nllloss(predict, label)
# output: tensor(0.2684)
而使用torch.nn.CrossEntropyLoss可以省去softmax + log
cross_loss = nn.CrossEntropyLoss()
predict = torch.Tensor([[2, 3, 1],
[3, 7, 9]])
label = torch.tensor([1, 2])
cross_loss(predict, label)
# output: tensor(0.2684)
本文详细介绍了PyTorch中的NLLLoss的运算规则及其与CrossEntropyLoss的区别。NLLLoss通常与softmax和log结合使用,CrossEntropyLoss则包含了softmax和log的操作。通过实例展示了NLLLoss的计算过程,并解释了不同reduction参数的影响。同时,对比了直接使用CrossEntropyLoss的便利性。
2885

被折叠的 条评论
为什么被折叠?



