input.argmax(1):横向比较
input.argmax(0):纵向比较
import torch
input = torch.tensor([[0.6,0.5,0.3],
[0.2,0.3,0.1]])
target1 = torch.tensor([0,1])
target2 = torch.tensor([0,0,1])
# 参数为1,横向比较
output1 = input.argmax(1)
print(output1) # torch.tensor([0,1])
print((output1 == target1).sum()) # 输出预测正确次数:tensor(2)
# 参数为0,纵向比较
output2 = input.argmax(0)
print(output2) # torch.tensor([0,0,0])
print((output2 == target2).sum()) # 输出预测正确次数:tensor(2)