import torch
output_1 = torch.tensor([[0.1,0.2],
[0.05,0.4]],dtype=torch.float32)
print(output_1.argmax(0)) # 参数设置为0,通过竖向寻找最大值,并获得最大值对应的索引
output_2 = torch.tensor([[0.1,0.2],
[0.3,0.4]],dtype=torch.float32)
print(output_2.argmax(1)) # 参数设置为1,通过横向寻找最大值,并获得最大值对应的索引
predict = output_2.argmax(1)
target = torch.tensor([0,1])
print((predict == target).sum()) # 统计预测准确的数量
pytorch入门20:准确率的使用
最新推荐文章于 2024-03-21 18:57:56 发布