import torch
import torch.nn as nn
loss_fn = nn.CrossEntropyLoss()
# 方便理解,此处假设batch_size = 1
x_input = torch.randn(2, 3) # 预测2个对象,每个对象分别属于三个类别分别的概率
# 需要的GT格式为(2)的tensor,其中的值范围必须在0-2(0<value<C-1)之间。
x_target = torch.tensor([0, 2]) # 这里给出两个对象所属的类别标签即可,此处的意思为第一个对象属于第0类,第二个我对象属于第2类
loss = loss_fn(x_input, x_target)
print('loss:\n', loss)
input 需要输入的是logits,logits参考这篇
logits含义
target 不能是ont-hot编码的,直接写
target = tensor([3])
即可,报错因为用了one-hot编码