Just a note for my own record; experienced readers can skip this.
Background
The network output contains one raw score (logit) per class; it is not yet a probability, because nn.CrossEntropyLoss applies log-softmax internally.
The label is the index of the ground-truth class.
So how exactly does nn.CrossEntropyLoss turn these two into a loss value?
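For a single sample with logit vector x and target index class, the loss reduces to the standard cross-entropy-over-logits formula, which the code below reproduces by hand:

\text{loss}(x, \text{class}) = -\log\frac{\exp(x[\text{class}])}{\sum_j \exp(x[j])} = -x[\text{class}] + \log\sum_j \exp(x[j])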
Code
import torch
import torch.nn as nn
import math

criterion = nn.CrossEntropyLoss()
output = torch.randn(1, 5, requires_grad=True)        # 1 sample, 5 classes: raw logits
label = torch.empty(1, dtype=torch.long).random_(5)   # random target class in [0, 5)
loss = criterion(output, label)

print(output)  # network output: one logit per class
print(label)   # target class to score against
print(loss)    # loss computed by nn.CrossEntropyLoss

# Reproduce the same value by hand: loss = -x[label] + log(sum_j exp(x[j]))
first = 0
for i in range(1):
    first = -output[i][label[i]]
print(first)

second = 0
for i in range(1):
    for j in range(5):
        second += math.exp(output[i][j])

res = first + math.log(second)
print("Manual calculation result:")
print(res)
Plugging in the five logits from one run of the script above (0.4045, -1.2018, -0.0459, 1.3131, 0.3205, where the target class's logit was 0.4045), the same value can be checked step by step in a Python shell:

exps = math.exp(0.4045) + math.exp(-1.2018) + math.exp(-0.0459) + math.exp(1.3131) + math.exp(0.3205)  # sum_j exp(x[j])
math.log(exps)             # the log-sum-exp term
-0.4045 + math.log(exps)   # -x[label] + log-sum-exp: matches the loss printed above
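As an extra cross-check (a minimal sketch, not part of the original post), the same decomposition can be confirmed with PyTorch's own building blocks: nn.CrossEntropyLoss is equivalent to log_softmax followed by NLLLoss, and the hand-computed sum is just torch.logsumexp.

import torch
import torch.nn as nn
import torch.nn.functional as F

output = torch.randn(1, 5)             # raw logits, 1 sample x 5 classes
label = torch.randint(0, 5, (1,))      # random target class

ce = nn.CrossEntropyLoss()(output, label)                          # built-in cross entropy
nll = nn.NLLLoss()(F.log_softmax(output, dim=1), label)            # log-softmax + NLL
manual = -output[0, label[0]] + torch.logsumexp(output[0], dim=0)  # -x[label] + logsumexp(x)

print(ce.item(), nll.item(), manual.item())  # all three agree up to floating-point error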
References
Thanks to the original blogger whose article this note follows.