这个在label为1维的时候能对的上。
2维测试交叉熵代码:
注意:
output 维度是[batch_size,所分类预测值,样本数]
label维度是[batch_size,样本数]
output = torch.randn(3, 3,5, requires_grad=True)
label = torch.empty((3,5), dtype=torch.long).random_(3)
import torch
import torch.nn as nn
import numpy as np
class CrossEntropyLoss(nn.Module):
def __init__(self):
super(CrossEntropyLoss, self).__init__()
def forward(self, output, label):
if label.dim()>1:
output=output.permute(0,2,1)
label=label.view(-1)
output=output.reshape((label.size(0),output.size(2)))
first = [-output[i][label[i]] for i in range(label.s