在Mnist数据集中,我们要得到的输出是0-9,共有十类,这种情况下我们希望输出0-9的概率都大于0,且和为1
是最后一层线性层的输出,
softmax函数:
损失函数:LOSS=-Y*logY_hat
NLLLoss:nagative log likelihood loss
输入一个 y是真实标签,另一个输入要求是softmax之后求对数
实现过程:
在PyTorch中,交叉熵损失全部封装成了Torch.nn.CrossEntropyLoss()
要求y是长整形张量LongTensor([0])表示第几个标签分类,在构造时直接用CrossEntropyLoss,然后计算loss
举例~
代码在pytroch中跑不通不知道为啥,参考了别人的代码
import torch
y = torch.LongTensor([2, 0, 1]) #