报错代码
def cross_entropy(y_hat, y):
tmp = torch.arange(y_hat.size(0))
print(range(len(y_hat)))
return -torch.log(y_hat[range(len(y_hat)), y])
y_hat = torch.arange(1, 21, 1).reshape(-1, 2)
y = torch.zeros(10)
print(y_hat.size(0))
print(cross_entropy(y_hat, y))
报错
IndexError: tensors used as indices must be long, int, byte or bool tensors
将y_hat[range(len(y_hat)), y]中的y强转为int即可
y_hat[range(len(y_hat)), y.type(torch.int)]