常见类型
torch.FloatTensor
torch.DoubleTensor
torch.HalfTensor
torch.ByteTensor
torch.CharTensor
torch.ShortTensor
torch.IntTensor
torch.LongTensor
完整Tensor类型参考:https://pytorch.org/docs/stable/tensors.html
查看类型
使用.type()
方法可以查看,例如:
import torch
x = torch.LongTensor([1, 2, 3, 4, 5, 6, 8])
print(x.type()) # torch.LongTensor
转换到其他类型
方法一:使用type(torch.xxxTensor)
比如:
import torch
x = torch.LongTensor([1, 2, 3, 4, 5, 6, 8])
print(x.type()) # torch.LongTensor
print(x.type(torch.FloatTensor).type()) # torch.FloatTensor
方法二:使用内置函数
比如:
import torch
x = torch.LongTensor([1, 2, 3, 4, 5, 6, 8])
print(x.type()) # torch.LongTensor
print(x.float().type()) # torch.FloatTensor
内置函数常见.float()
,.int()
,.long()
,.byte()
等