在Mnist数据集中,我们要得到的输出是0-9,共有十类,这种情况下我们希望输出0-9的概率都大于0,且和为1
是最后一层线性层的输出,
softmax函数:
损失函数:LOSS=-Y*logY_hat
NLLLoss:nagative log likelihood loss
输入一个 y是真实标签,另一个输入要求是softmax之后求对数
实现过程:
在PyTorch中,交叉熵损失全部封装成了Torch.nn.CrossEntropyLoss()
要求y是长整形张量LongTensor([0])表示第几个标签分类,在构造时直接用CrossEntropyLoss,然后计算loss
举例~
代码在pytroch中跑不通不知道为啥,参考了别人的代码
import torch
y = torch.LongTensor([2, 0, 1]) # 注意此处是LongTensor
# z_1和z_2是最后一层输出,进入Softmax之前的值,所以每个分类之和不为1
# 每行元素代表对一个对象的分类情况,共三个对象
z_1 = torch.Tensor([[0.1, 0.2, 0.9],
[1.1, 0.1, 0.2],
[0.2, 2.1, 0.1]])
z_2 = torch.Tensor([[0.9, 0.2, 0.1],
[0.1, 0.1, 0.5],
[0.2, 0.1, 0.7]])
criterion = torch.nn.CrossEntropyLoss()
print(criterion(z_1, y), criterion(z_2, y))
结果:
输出为: