对于target为类id,其实做了one-hot 操作,例如对于总共有三类,其中id为2,则转换后的标签如下:[0,0,1]。 这些标签作为权重乘上input的值进行叠加。
参见代码,秒懂。 三个输出一致
if __name__ == "__main__":
import torch
import torch.nn as nn
nllloss = nn.NLLLoss()
x = torch.tensor([[1.5,2.5,3.0],
[1.2,2.0,2.9]])
onehot_y = torch.tensor([[0,1.0,0],
[0,0,1]])
logsoft_out = nn.LogSoftmax()(x)
y = torch.tensor([1,2])
print(nllloss(logsoft_out,y))
print(nn.CrossEntropyLoss()(x,onehot_y))
print(nn.CrossEntropyLoss()(x,y))
exit()
如果input和target都是相同维度,例如3x5。
其实做了一个这样的操作,
torch.matmul(input, target.T), 再对这个3x3的矩阵[0,0],[1,1],[2,2]的值累加做平均