Cross-entropy loss
交叉熵的理论推导网络上有很多资料，这里只关注 PyTorch 内部的具体计算方式。
import torch
import torch.nn as nn
# Network logits, taken from the "Deep-Eye" PyTorch lesson-15 RMB
# binary-classification example (BATCH changed to 5, CLASS = 2).
out = torch.tensor([
    [-0.0223, 0.2420],
    [0.1782, 0.6221],
    [0.0887, 0.4575],
    [0.3041, 0.3169],
    [0.1052, 0.0649],
])
# Ground-truth class index for each of the 5 samples.
label = torch.tensor([1, 0, 1, 0, 1])

print(f"网络输出out = {out}")

# Step 0: verify that Softmax followed by log matches LogSoftmax.
print("0 - 比较手动log, softmax和LogSoftmax")
sm = nn.Softmax(dim=1)
print(f"经过Softmax = {sm(out)}")
print(f"再做Log = {torch.log(sm(out))}")
lsm = nn.LogSoftmax(dim=1)
lsm_result = lsm(out)
print(f"直接使用LogSoftmax = {lsm_result}")
print("\n")

# Step 1: NLLLoss applied to the log-probabilities.
print("1 - 使用NLLLoss")
loss = nn.NLLLoss()
print(f"NLLLoss = {loss(lsm_result, label)}")
print("\n")

# Step 2: replicate NLLLoss by hand — pick each sample's log-probability
# at its target class, then negate the mean.
print("2 - 手工计算loss")
ce = [lsm_result[row, col].item() for row, col in enumerate(label.tolist())]
print(ce)
ce_tensor = torch.tensor(ce)
print(ce_tensor)
print(f"手工计算Loss = {-ce_tensor.mean()}")
print("\n")

# Step 3: CrossEntropyLoss fuses LogSoftmax + NLLLoss in one module.
print("3 - 使用CrossEntropyLoss")
loss = nn.CrossEntropyLoss()
print(f"使用CrossEntropyLoss = {loss(out, label)}")