import torch
outputs = torch.tensor([[0.1, 0.2],
[0.05, 0.4]])
print(outputs.argmax(1)) # 里面的1是一个方向,代表矩阵横向
# 输出tensor([1, 1]) 说明0.2在第一行最大,0.4在第二行最大
preds = outputs.argmax(1)
targets = torch.tensor([0, 1])
print(preds == targets)
# 因为tensor([1, 1])和tensor([0, 1])
# 所以输出tensor([False, True])
print((preds == targets).sum())
# 输出tensor(1),这是对应位置相等的个数
print(outputs.argmax(0)) # 里面的1是一个方向,代表矩阵的列
# 输出tensor([0, 1]) 说明0.1在第一列最大,0.4在第二列最大
argmax()和.sum()用法
最新推荐文章于 2024-04-27 12:07:25 发布