测试代码:
# Demo: nn.CrossEntropyLoss on one random 5-logit input (shape (1, 5), requires_grad=True)
# and a random int64 class-index target drawn by random_(5), i.e. a value in [0, 5).
# NOTE(review): the statements below appear to have had their line breaks collapsed into
# spaces, so this line is not valid Python as written -- TODO restore one statement per line.
import torch import torch.nn as nn import math loss = nn.CrossEntropyLoss() input = torch.randn(1, 5, requires_grad=True) target = torch.empty(1, dtype=torch.long).random_(5) output = loss(input, target) print("输入为5类:") print(input) print("要计算loss