本文讲述pytorch张量torch.Tensor类型的构建与相互转换以及torch.type()和torch.type_as()的用法
(1) pytorch中张量的定义与类型
pytorch的张量定义有很多种,对于cpu版本有如下七种,对于gpu版本有八种类型,gpu版本的张量只需要在cpu版本的基础上增加cuda就可以了:torch.cuda.DoubleTensor(2,3)
<1>torch.FloatTensor(2,3) 构建一个2*3 Float类型的张量
<2>torch.DoubleTensor(2,3) 构建一个2*3 Double类型的张量
<3>torch.ByteTensor(2,3) 构建一个2*3 Byte类型的张量
<4>torch.CharTensor(2,3) 构建一个2*3 Char类型的张量
<5>torch.ShortTensor(2,3) 构建一个2*3 Short类型的张量
<6>torch.IntTensor(2,3) 构建一个2*3 Int类型的张量
<7>torch.LongTensor(2,3) 构建一个2*3 Long类型的张量
torch.Tensor是默认的tensor类型(torch.FloatTensor)的简称。
(2)张量类型转换
[1]直接转换
tensor = torch.FloatTensor(2,2)
float_tensor = tensor.float()
double_tensor = tensor.double()
long_tensor = tensor.long()
int_tensor = tensor.int()
char_tensor = tensor.char()
byte_tensor = tensor.byte()
short_tensor = tensor.short()
[2]使用tensor.type()
tensor = torch.FloatTensor(2,2)
tensor.type(torch.FloatTensor)
PS:如果只是调用.type(),则返回tensor的类型
[3]使用tensor.type_as()
使用type_as(a)将tensor转化成a的类型
tensor = torch.FloatTensor(2,2)
tensor_ = torch.IntTensor(3,3)
print(tensor.type_as(tensor_))
# 没有将结果覆盖tensor的话,tensor类型保持不变
# 如下结果覆盖
tensor = tensor.type_as(tensor_)
# 类型转换跟tensor的维度没有关系