各位小伙伴肯定看到过下面这段代码:
correct += (predicted == labels).sum().item()
这里面(predicted == labels)
是布尔型,为什么可以接sum()呢?
我做了个测试,如果这里的predicted和labels是列表形式就会报错,如果是numpy的数组格式,会返回一个值,如果是tensor形式,就会返回一个张量。
举个例子:
import torch
a = torch.tensor([1,2,3])
b = torch.tensor([1,3,2])
print((a == b).sum())
上述代码的输出结果:
tensor(1)
如果将a和b改成numpy下的数组格式:
import numpy as np
a = np.array([1,2,3])
b = np.array([1,3,2])
print((a == b).sum())
上述代码的输出结果:
1
如果将a和b改成列表:
a = [1,2,3]
b = [1,3,2]
print((a == b).sum())
上述代码的输出结果:
Traceback (most recent call last):
File "路径", line 4, in <module>
print((a == b).sum())
AttributeError: 'bool' object has no attribute 'sum'
Process finished with exit code 1
Added:
.item()用于取出tensor中的值。