原来 y = [[1,0,0],[0,0,1]]
y = torch.tensor([0,2])
y_hat = torch.tensor([[0.1,0.1,0.8],[0.2,0.3,0.5]])
def Cross_Entropy(y_hat, y):
return -torch.log(y_hat[range(len(y)),y])
# y_hat[range(len(y)),y]:
# range(len(y)) 第一维 每行
# y列数 0:第一行 2:第二行
类别个数m,预测的总个数n
y 是种类标签,即第一个的类别是0,第二个的类别是2;
y_hat 不同类别的概率;