基本数据:Tensor
Tensor,即张量,是pytorch中的基本操作对象。
Tensor数据类型
数据类型 | CPU Tensor | GPU Tensor |
---|---|---|
32位浮点 | torch.FloatTensor | torch.cuda.FloatTensor |
64位浮点 | torch.DoubleTensor | torch.cuda.DoubleTensor |
16位半精度浮点 | N/A | torch.cuda.HalfTensor |
8位无符号整型 | torch.ByteTensor | torch.cuda.ByteTensor |
8位有符号整型 | torch.CharTensor | torch.cuda.CharTensor |
16位有符号整型 | torch.ShortTensor | torch.cuda.ShortTensor |
32位有符号整型 | torch.IntTensor | torch.cuda.IntTensor |
64位有符号整型 | torch.LongTensor | torch.cuda.LongTensor |
pytorch可以通过set_default_tensor_type()
函数设置默认使用的Tersor类型,在局部用完后如果需要其他类型则需要返回重新设置。
torch.set_default_tensor_type('torch.DoubleTensor')
对于Tensor的类型转换可以使用type(new_type)
,type_as()
,int()
等多种方式。
Tensor的创建与维度查看
- 基础Tensor函数:
torch.Tensor(2,2)
- 指定类型:
torch.DoubleTensor(2,2)
- 使用python的list序列:
torch.Tensor([[1,2],[3,4]])
- 默认值为0:
torch.zeros(2,2)
- 默认值为1:
torch.ones(2,2)
- 对角张量:
torch.eye(2,2)
- 随机张量:
torch.randn(2,2)
- 随机排列张量:
torch.randperm(4)
对于维度,可以使用Tensor.shape
或者方法size()
进行查看。
查看Tensor中元素总个数,可使用Tensor.numel()
。
Tensor的组合和分块
组合是吧不同的Tensor叠加起来,主要有torch.cat()
和torch.stack()
两个函数。前者为沿着某一维度进行拼接,后者为新增维度。