t = pred.max(-1)[1] == label是什么意思?

t = pred.max(-1)[1] == lab
在学习pytorch的过程中,不止一个人问过这个问题。这行代码常用于统计正确率中,下面让我们一起来探究一下到底是什么逻辑。

def test(pred, lab):
    t = pred.max(-1)[1] == lab
    return torch.mean(t.float())

首先,我们需要确定一点:==的优先级高于=。
上述一行可以写为:

t = (pred.max(-1)[1] == lab)

此外,关于torch.max可以参考:https://blog.csdn.net/ftimes/article/details/118189054

那么pred.max(-1)[1]返回的就是fc输出的两个结果中最大值的维度。
前面提到,我们已经把标签处理成了0和1
pred.max(-1)[1]返回的最大值索引也是0和1(虽然有更好的办法,但这样做也可行)

pred.max(-1)
Out[2]: 
torch.return_types.max(
values=tensor([0.7311, 0.4735, 0.7676, 0.4959, 0.7937, 0.6112, 0.5918, 0.8628, 0.6572,
        0.6233, 0.6227, 0.5958, 0.5774, 0.6474, 0.8323, 0.6375, 0.6602, 0.5261,
        0.4647, 0.4738, 0.6257, 0.5601, 0.7112, 0.6483, 0.4457, 0.7498, 0.7358,
        0.7040, 0.4066, 0.5518, 0.7691, 0.6721, 0.7550, 0.8085, 0.4990, 0.4476,
        0.4958, 0.8220, 0.8659, 0.7823, 0.4055, 0.5893, 0.9285, 0.8180, 0.7205,
        0.7187, 0.6130, 0.6980, 0.6169, 0.6855, 0.4833, 0.6585, 0.7132, 0.6194,
        0.5593, 0.7338, 0.6272, 0.5971, 0.8424, 0.8467, 0.6150, 0.4855, 0.7193,
        0.7658, 0.5708, 0.6497, 0.7794, 0.6412, 0.8045, 0.5714, 0.5182, 0.7838,
        0.7479, 0.6859, 0.7567, 0.4332, 0.5189, 0.7014, 0.5938, 0.7905, 0.7039,
        0.6159, 0.7387, 0.8065, 0.8136, 0.5870, 0.5591, 0.8067, 0.6246, 0.6466,
        0.5937, 0.6113, 0.7880, 0.4537, 0.6373, 0.8012, 0.7067, 0.7080, 0.5979,
        0.8495], grad_fn=<MaxBackward0>),
indices=tensor([0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1,
        1, 0, 0, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 1, 1, 0,
        0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0,
        1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 1, 0]))

如上,pred.max(-1)[1] == lab得到一个bool张量

pred.max(-1)[1] == lab
Out[4]: 
tensor([ True,  True,  True,  True, False,  True,  True,  True,  True,  True,
         True,  True,  True, False, False, False,  True, False, False, False,
        False,  True,  True, False, False, False,  True,  True,  True,  True,
         True, False,  True, False,  True, False, False,  True,  True,  True,
        False, False, False,  True,  True, False, False,  True,  True, False,
         True,  True, False,  True,  True,  True, False, False,  True, False,
         True,  True,  True, False,  True,  True, False,  True,  True, False,
         True,  True,  True,  True, False,  True, False,  True, False, False,
         True,  True,  True,  True,  True,  True,  True, False,  True,  True,
         True,  True,  True, False,  True,  True,  True,  True,  True,  True])

这就是一个bacth的预测结果(正确与否),也就是我们得到的 t
用t.float()将bool转换为0和1

t.float()
Out[8]: 
tensor([1., 1., 1., 1., 0., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 1., 0.,
        0., 0., 0., 1., 1., 0., 0., 0., 1., 1., 1., 1., 1., 0., 1., 0., 1., 0.,
        0., 1., 1., 1., 0., 0., 0., 1., 1., 0., 0., 1., 1., 0., 1., 1., 0., 1.,
        1., 1., 0., 0., 1., 0., 1., 1., 1., 0., 1., 1., 0., 1., 1., 0., 1., 1.,
        1., 1., 0., 1., 0., 1., 0., 0., 1., 1., 1., 1., 1., 1., 1., 0., 1., 1.,
        1., 1., 1., 0., 1., 1., 1., 1., 1., 1.])

计算平均值:

torch.mean(t.float())
Out[9]: tensor(0.6600)

得到该epoch这个bacth的准确率为0.66
这样计算等价于

t.float().sum()/t.numel() #正确的个数除以总个数
Out[12]: tensor(0.6600)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值