import torch
import torch.nn.functional as F
def manual_cross_entropy(input, target):
    """Compute the mean cross-entropy loss from raw logits.

    Args:
        input: float tensor of shape [N, C] — unnormalized logits.
        target: long tensor of shape [N] — class index per sample.

    Returns:
        Scalar tensor: mean negative log-likelihood over the batch.
    """
    # Shift each row by its max so exp() cannot overflow (log-sum-exp trick).
    shifted = input - input.max(dim=1, keepdim=True).values  # [N, C]
    # Compute log-softmax directly instead of softmax(...).log():
    # softmax underflows to 0 for very negative shifted logits, and
    # log(0) = -inf would poison the loss. The direct form stays finite.
    log_probs = shifted - shifted.exp().sum(dim=1, keepdim=True).log()  # [N, C]
    # Pick each sample's log-probability of its true class.
    nll = -log_probs[torch.arange(target.size(0)), target]  # [N]
    return nll.mean()  # scalar
# Sanity check: the manual implementation should match torch's built-in
# F.cross_entropy on a tiny two-sample example.
example_logits = torch.tensor([[1.0, 2.0, 0.1], [0.5, 1.5, 2.0]])  # [N, C]
example_labels = torch.tensor([1, 2])  # [N]
builtin_loss = F.cross_entropy(example_logits, example_labels)
manual_loss = manual_cross_entropy(example_logits, example_labels)
print('F.cross_entropy的交叉熵损失:', builtin_loss.item())
print('手动实现的交叉熵损失:', manual_loss.item())