For more background, you can also refer to the CSDN post "Softmax, Log_Softmax, NLLLoss and CrossEntropyLoss in PyTorch: relationships and differences explained" (NeilPy's blog).
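As a quick illustration of that relationship (a minimal sketch with made-up tensor values), nn.CrossEntropyLoss produces the same value as applying nn.LogSoftmax followed by nn.NLLLoss:

import torch
import torch.nn as nn

logits = torch.tensor([[0.5, 1.2, -0.3]])   # hypothetical scores for one sample, 3 classes
target = torch.tensor([1])                  # hypothetical target class

ce = nn.CrossEntropyLoss()(logits, target)
# equivalent: log-softmax of the scores, then negative log-likelihood
nll = nn.NLLLoss()(nn.LogSoftmax(dim=1)(logits), target)
print(ce.item(), nll.item())                # the two values are identical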
Based on the formula above, we carry out the following calculation:
import torch
import torch.nn as nn

loss = nn.CrossEntropyLoss()                           # default reduction='mean'
input = torch.randn(3, 5, requires_grad=True)          # 3 samples, 5 classes
print(input)
target = torch.empty(3, dtype=torch.long).random_(5)   # random class index in [0, 5) for each sample
print(target)
output = loss(input, target)
print(output)
output.backward()
Program output:
input tensor([[ 0.8613, 0.2848, -0.9878, 1.6137, 1.6703],
[-0.5740, 0.6567, -0.7853, -1.5065, 1.3024],
[-1.2544, -0.7814, 0.0204, -0.7491, 0.3055]], requires_grad=True)
tensor([0, 2, 3])
loss:tensor(2.1812, grad_fn=<NllLossBackward>)
Manual verification:
With the default reduction='mean', CrossEntropyLoss averages the per-sample losses, and each sample's loss is

loss(x, class) = -x[class] + log( Σ_j exp(x[j]) )
Sample 1 (target class 0):

import math

listnum = [0.8613, 0.2848, -0.9878, 1.6137, 1.6703]
total = 0
for num in listnum:
    total = total + math.exp(num)
loss1 = -listnum[0] + math.log(total)
print(loss1)   # 1.8061534287700232
Sample 2 (target class 2):

listnum = [-0.5740, 0.6567, -0.7853, -1.5065, 1.3024]
total = 0
for num in listnum:
    total = total + math.exp(num)
loss2 = -listnum[2] + math.log(total)
print(loss2)   # 2.709178782095951
Sample 3 (target class 3):

listnum = [-1.2544, -0.7814, 0.0204, -0.7491, 0.3055]
total = 0
for num in listnum:
    total = total + math.exp(num)
loss3 = -listnum[3] + math.log(total)
print(loss3)   # 2.028286903303233
loss = (loss1 + loss2 + loss3) / 3 = 2.181206371389736, which matches the result above.
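As a cross-check, the same three per-sample losses and their mean can be computed in one shot with torch.logsumexp, reusing the input and target values printed above (a minimal vectorized sketch):

import torch

x = torch.tensor([[ 0.8613,  0.2848, -0.9878,  1.6137,  1.6703],
                  [-0.5740,  0.6567, -0.7853, -1.5065,  1.3024],
                  [-1.2544, -0.7814,  0.0204, -0.7491,  0.3055]])
t = torch.tensor([0, 2, 3])

# per-sample loss: -x[i, t[i]] + log(sum_j exp(x[i, j]))
per_sample = -x[torch.arange(3), t] + torch.logsumexp(x, dim=1)
print(per_sample)         # ≈ tensor([1.8062, 2.7092, 2.0283])
print(per_sample.mean())  # ≈ tensor(2.1812), same as nn.CrossEntropyLoss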