PyTorch系列 | correct += (predicted == labels).sum().item()的理解

各位小伙伴肯定看到过下面这段代码:

correct += (predicted == labels).sum().item()

这里面(predicted == labels)是布尔型,为什么可以接sum()呢?

我做了个测试,如果这里的predicted和labels是列表形式就会报错,如果是numpy的数组格式,会返回一个值,如果是tensor形式,就会返回一个张量。

举个例子:

import torch

a = torch.tensor([1,2,3])
b = torch.tensor([1,3,2])

print((a == b).sum())

上述代码的输出结果:

tensor(1)

如果将a和b改成numpy下的数组格式:

import numpy as np

a = np.array([1,2,3])
b = np.array([1,3,2])

print((a == b).sum())

上述代码的输出结果:

1

如果将a和b改成列表:

a = [1,2,3]
b = [1,3,2]

print((a == b).sum())

上述代码的输出结果:

Traceback (most recent call last):
  File "路径", line 4, in <module>
    print((a == b).sum())
AttributeError: 'bool' object has no attribute 'sum'

Process finished with exit code 1

Added:

.item()用于取出tensor中的值。

评论 7
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值