Using this form of the loss function makes it easy to avoid overflow and underflow.
My own implementation:
import torch
from torch import nn

def cross_entropy(y_hat, y):
    numerator = y_hat[range(len(y_hat)), y]                   # logit of the true class (log of the softmax numerator)
    denominator = torch.log(torch.sum(torch.exp(y_hat), 1))   # log-sum-exp over classes (log of the softmax denominator)
    l = -numerator + denominator                              # per-sample cross-entropy
    return l.sum() / len(l)                                   # mean over the batch
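The overflow issue is easy to see directly. Below is a minimal check (the name big_logits is just for illustration, not from the original note): with very large logits, the naive torch.exp followed by torch.log blows up to inf, while torch.logsumexp computes the same log-denominator stably, so the denominator line above could equally be written as torch.logsumexp(y_hat, 1).

big_logits = torch.tensor([[1000.0, 0.0, -1000.0]])
print(torch.log(torch.sum(torch.exp(big_logits), 1)))   # tensor([inf]) -- exp(1000.) overflows in float32
print(torch.logsumexp(big_logits, 1))                   # tensor([1000.]) -- numerically stable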
Calling the official API:
loss1 = nn.CrossEntropyLoss()   # default reduction='mean' matches the division by len(l) above
The results are the same:
y = torch.tensor([0, 2])
y_hat = torch.tensor([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])
print(cross_entropy(y_hat, y))
print(loss1(y_hat, y))
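As a sketch of where the agreement comes from (this decomposition is not in the original note, just a standard equivalence): nn.CrossEntropyLoss applies log-softmax to the raw logits and then takes the negative log-likelihood with mean reduction, which is exactly what the hand-written function computes.

from torch.nn import functional as F

log_probs = F.log_softmax(y_hat, dim=1)   # log of the softmax probabilities, computed stably
print(F.nll_loss(log_probs, y))           # same mean loss as cross_entropy(y_hat, y) and loss1(y_hat, y)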