t = pred.max(-1)[1] == lab
在学习pytorch的过程中,不止一个人问过这个问题。这行代码常用于统计正确率中,下面让我们一起来探究一下到底是什么逻辑。
def test(pred, lab):
t = pred.max(-1)[1] == lab
return torch.mean(t.float())
首先,我们需要确定一点:==的优先级高于=。
上述一行可以写为:
t = (pred.max(-1)[1] == lab)
此外,关于torch.max可以参考:https://blog.csdn.net/ftimes/article/details/118189054
那么pred.max(-1)[1]返回的就是fc输出的两个结果中最大值的维度。
前面提到,我们已经把标签处理成了0和1
pred.max(-1)[1]返回的最大值索引也是0和1(虽然有更好的办法,但这样做也可行)
pred.max(-1)
Out[2]:
torch.return_types.max(
values=tensor([0.7311, 0.4735, 0.7676, 0.4959, 0.7937, 0.6112, 0.5918, 0.8628, 0.6572,
0.6233, 0.6227, 0.5958, 0.5774, 0.6474, 0.8323, 0.6375, 0.6602, 0.5261,
0.4647, 0.4738, 0.6257, 0.5601, 0.7112, 0.6483, 0.4457, 0.7498, 0.7358,
0.7040, 0.4066, 0.5518, 0.7691, 0.6721, 0.7550, 0.8085, 0.4990, 0.4476,
0.4958, 0.8220, 0.8659, 0.7823, 0.4055, 0.5893, 0.9285, 0.8180, 0.7205,
0.7187, 0.6130, 0.6980, 0.6169, 0.6855, 0.4833, 0.6585, 0.7132, 0.6194,
0.5593, 0.7338, 0.6272, 0.5971, 0.8424, 0.8467, 0.6150, 0.4855, 0.7193,
0.7658, 0.5708, 0.6497, 0.7794, 0.6412, 0.8045, 0.5714, 0.5182, 0.7838,
0.7479, 0.6859, 0.7567, 0.4332, 0.5189, 0.7014, 0.5938, 0.7905, 0.7039,
0.6159, 0.7387, 0.8065, 0.8136, 0.5870, 0.5591, 0.8067, 0.6246, 0.6466,
0.5937, 0.6113, 0.7880, 0.4537, 0.6373, 0.8012, 0.7067, 0.7080, 0.5979,
0.8495], grad_fn=<MaxBackward0>),
indices=tensor([0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1,
1, 0, 0, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 1, 1, 0,
0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0,
1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 1, 0]))
如上,pred.max(-1)[1] == lab得到一个bool张量
pred.max(-1)[1] == lab
Out[4]:
tensor([ True, True, True, True, False, True, True, True, True, True,
True, True, True, False, False, False, True, False, False, False,
False, True, True, False, False, False, True, True, True, True,
True, False, True, False, True, False, False, True, True, True,
False, False, False, True, True, False, False, True, True, False,
True, True, False, True, True, True, False, False, True, False,
True, True, True, False, True, True, False, True, True, False,
True, True, True, True, False, True, False, True, False, False,
True, True, True, True, True, True, True, False, True, True,
True, True, True, False, True, True, True, True, True, True])
这就是一个bacth的预测结果(正确与否),也就是我们得到的 t
用t.float()将bool转换为0和1
t.float()
Out[8]:
tensor([1., 1., 1., 1., 0., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 1., 0.,
0., 0., 0., 1., 1., 0., 0., 0., 1., 1., 1., 1., 1., 0., 1., 0., 1., 0.,
0., 1., 1., 1., 0., 0., 0., 1., 1., 0., 0., 1., 1., 0., 1., 1., 0., 1.,
1., 1., 0., 0., 1., 0., 1., 1., 1., 0., 1., 1., 0., 1., 1., 0., 1., 1.,
1., 1., 0., 1., 0., 1., 0., 0., 1., 1., 1., 1., 1., 1., 1., 0., 1., 1.,
1., 1., 1., 0., 1., 1., 1., 1., 1., 1.])
计算平均值:
torch.mean(t.float())
Out[9]: tensor(0.6600)
得到该epoch这个bacth的准确率为0.66
这样计算等价于
t.float().sum()/t.numel() #正确的个数除以总个数
Out[12]: tensor(0.6600)