import torch
import numpy as np
data1 = np.array([
[1,2,3],
[2,3,4]
])
data1_torch = torch.from_numpy(data1)
data2 = np.array([
[1,2,3],
[2,3,4]
])
data2_torch = torch.from_numpy(data2)
p = (data1_torch == data2_torch) #对比后相同的值会为1,不同则会为0
print p
print type(p)
d1 = p.sum() #将所有的值相加,得到的仍是tensor类别的int值
print d1
print type(d1)
d2 = d1.item() #转成python数字
print d2
print type(d2)
输出:
tensor([[1, 1, 1],
[1, 1, 1]], dtype=torch.uint8)
<class 'torch.Tensor'>
tensor(6)
<class 'torch.Tensor'>
6
<type 'int'>
从而理解:
correct += (predicted == labels).sum().item() # 如果预测结果和真实值相等则计数 +1