在tensor后面加.item()
即可。
import torch
a = torch.LongTensor([10])
b = a.data # 还是tensor
c = a.item() # int类型,而且只有单个元素才可以用.item()转化
print(b, type(b))
print(c, type(c))
'''输出
tensor([10]) <class 'torch.Tensor'>
10 <class 'int'>
'''
在tensor后面加.item()
即可。
import torch
a = torch.LongTensor([10])
b = a.data # 还是tensor
c = a.item() # int类型,而且只有单个元素才可以用.item()转化
print(b, type(b))
print(c, type(c))
'''输出
tensor([10]) <class 'torch.Tensor'>
10 <class 'int'>
'''