在计算Loss和accuracy的时候经常会遇到.item()
其作用是取出tensor中的值,变为Python的数据类型。
import torch
x = torch.randn(2,2)
print(x)
print(x[1,1])
输出:
tensor([[-1.3722, 1.1144],
[-0.8714, 0.3114]])
tensor(0.3114)
=========================================
import torch
x = torch.randn(2,2)
print(x)
print(x[1,1].item())
输出:
tensor([[-0.3186, 0.6605],
[-0.2433, -0.7404]])
-0.7403690814971924