pytorch学习笔记——CrossEntropyLoss、NLLLoss、Softmax和LogSoftmax之间的关系

import torch.nn as nn
import torch

data = torch.Tensor([[-0.2733, 0.3222, 0.2605],
                     [1.5393, 1.1688, -0.0975],
                     [0.3943, 0.5172, -0.9425]])  # 一个3*3的矩阵
print(data)
'''
程序运行的一次结果
tensor([[-0.2733,  0.3222,  0.2605],
        [ 1.5393,  1.1688, -0.0975],
        [ 0.3943,  0.5172, -0.9425]])
'''
sm = nn.Softmax(dim=1)  # 按行 softmax
print(sm(data))
'''
程序运行的一次结果
tensor([[0.2213, 0.4014, 0.3774],
        [0.5305, 0.3663, 0.1032],
        [0.4178, 0.4725, 0.1098]])
'''
print(torch.log(sm(data)))  # log(softmax)
'''
程序运行的一次结果
tensor([[-1.5084, -0.9129, -0.9746],
        [-0.6339, -1.0044, -2.2707],
        [-0.8728, -0.7498, -2.2095]])
'''
slm = nn.LogSoftmax(dim=1)
print(slm(data))  # LogSoftmax
'''
程序运行的一次结果
tensor([[-1.5084, -0.9129, -0.9746],
        [-0.6339, -1.0044, -2.2707],
        [-0.8728, -0.7498, -2.2095]])
'''
# 结论:nn.LogSoftmax = torch.log(nn.Softmax)

loss = nn.NLLLoss()
target1 = torch.tensor([0, 1, 2])  # 随便写一个目标tensor
print(loss(data, target1))  # NLLLoss原始损失
'''
程序运行的一次结果
tensor(0.0157)
(0.2733+0.9425-1.1688)/3 ≈ 0.0157
'''
print(loss(slm(data), target1))  # NLLLoss对LogSoftmax处理后的数据的损失
'''
程序运行的一次结果
tensor(1.5741) 
(1.5084+1.0044+2.2095)/3=1.5741
'''
loss2 = nn.CrossEntropyLoss()
print(loss2(data, target1))  # CrossEntropyLoss损失
'''
程序运行的一次结果
tensor(1.5741)
'''
# 结论:nn.CrossEntropyLoss(input, target1) = nn.NLLLoss(nn.LogSoftmax(input), target1)

target2 = torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]])  # one-hot 标签
custom_loss = -torch.sum(slm(data) * target2) / 3
print(custom_loss)
'''
程序运行的一次结果
tensor(1.5741)
'''

下面是nn.NLLLoss()函数的公式。
ℓ ( x , y ) = L = l 1 , … , l N ⊤ l n = − w y n x n , y n , w c = w e i g h t [ c ] ⋅ 1 { c ≠ i g n o r e _ i n d e x } ℓ(x,y)=L={l_1,…,l_N}^⊤ \\ l_n=−w_{yn}x_{n,y_n}, w_c=weight[c]⋅1\{c\neq ignore\_index\} \\ (x,y)=L=l1,,lNln=wynxn,yn,wc=weight[c]1{c=ignore_index}
ℓ ( x , y ) = { ∑ n = 1 N l n ∑ n = 1 N w y n , i f r e d u c t i o n = ′ m e a n ′ ; ∑ n = 1 N l n , i f r e d u c t i o n = ′ s u m ′ . ℓ(x,y)=\begin{cases} \sum_{n=1}^{N} \frac{l_n}{\sum^{N}_{n=1}w_{yn}},if \quad reduction='mean';\\ \sum_{n=1}^{N} l_n,\qquad \quad if \quad reduction='sum'. \end{cases} (x,y)={n=1Nn=1Nwynln,ifreduction=mean;n=1Nln,ifreduction=sum.
1.默认情况下weight为1,上述代码中,如果NLLLoss有weight参数,那么weight=torch.Tensor([1, 1, 1]),即上述代码中nn.NLLLoss(weight=torch.Tensor([1, 1, 1]))与nn.NLLLoss()等价。
2.默认情况下,nn.NLLLoss()的reduction为mean。

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

xrn1997

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值