# Reference: "PyTorch 详解 NLLLoss 和 CrossEntropyLoss" (PyTorch: NLLLoss and CrossEntropyLoss explained)
# 来吧,展示! (Let's demonstrate!)
# _*_ coding:utf-8 _*_
# 参考 https://blog.csdn.net/qq_22210253/article/details/85229988?utm_medium=distribute.pc_relevant_t0.none-task-blog-2%7Edefault%7EBlogCommendFromMachineLearnPai2%7Edefault-1.control&depth_1-utm_source=distribute.pc_relevant_t0.none-task-blog-2%7Edefault%7EBlogCommendFromMachineLearnPai2%7Edefault-1.control
# Pytorch详解NLLLoss和CrossEntropyLoss
import torch
import torch.nn as nn
# # 1. Step by step: softmax -> log -> negate -> average over targets
# tensor = torch.tensor([[-0.1342, -2.5835, -0.9810],
# [0.1867, -1.4513, -0.3225],
# [0.6272, -0.1120, 0.3048]])
# sm = nn.Softmax(dim=1)
# sm = sm(tensor)
# sm = torch.log(sm)
# print(sm * -1) # tensor([[0.4155, 2.8648, 1.2623],
# # [0.5852, 2.2232, 1.0944],
# # [0.7893, 1.5285, 1.1117]])
# loss = (0.4155 + 1.0944 + 1.5285) / 3 # target [0,2,1]
# print(loss) # 1.0128
# # 2. NLLLoss applied to log-softmax outputs
# tensor = torch.tensor([[-0.1342, -2.5835, -0.9810],
# [0.1867, -1.4513, -0.3225],
# [0.6272, -0.1120, 0.3048]])
#
# loss = nn.NLLLoss()
# sm = nn.Softmax(dim=1)
# sm = sm(tensor)
# sm = torch.log(sm)
# target = torch.tensor([0,2,1])
# loss = loss(sm,target)
# print(loss) # tensor(1.0128)
# 3. CrossEntropyLoss (= LogSoftmax + NLLLoss in a single call)
# CrossEntropyLoss fuses LogSoftmax and NLLLoss, so it is applied directly
# to raw (unnormalized) logits — no explicit softmax/log step is needed.
# Each row of `logits` is one sample's scores over 3 classes.
logits = torch.tensor([[-0.1342, -2.5835, -0.9810],
                       [0.1867, -1.4513, -0.3225],
                       [0.6272, -0.1120, 0.3048]])
# One ground-truth class index per row.
target = torch.tensor([0, 2, 1])
# Use a distinct name for the loss module so it is not clobbered by the
# resulting scalar tensor (the original reused `loss` for both).
criterion = nn.CrossEntropyLoss()
loss = criterion(logits, target)  # mean NLL over the 3 samples
print(loss)  # tensor(1.0128) — matches the manual computation in sections 1 and 2