准确率原理示例
import torch
output = torch.tensor([[0.1, 0.2],
[0.2, 0.3]])
# 计算准确率:
# 输入预测的各个类别概率
# 取最大概率索引
# 与标签进行对比
# 累加对的个数 / 总数 = 准确率
# 计算矩阵 列(1)行(0) 最大值索引
print(output.argmax(1))
preds = output.argmax(1)
# 假定类型目标值标签, 第0类、第1类
targets = torch.tensor([0, 1])
# 比对
print(preds == targets)
# 累加
a = (preds == targets).sum()
print(a)