torch中的交叉熵损失函数使用案例
import torch
import torch.nn.functional as F
pred = torch.randn(3, 5)
print(pred.shape)
target = torch.tensor([2, 3, 4]).long() # 需要是整数
print(target.shape)
# 交叉熵损失函数, 输入的参数是形状不一样的
# predict会在其内部进行softmax操作
loss = F.cross_entropy(pred, target)
loss.item()
结果为:
需要注意的是, 传入的参数形状是不同的, predict是softmax之前的, 另外y需要是整形的, int也行