1、Tensor基本介绍
Tensor是PyTorch中最为重要的一种数据结构。张量是向量、矩阵等更普遍的高阶表达形式,与numpy中的array基本等同,使用方法也非常类似。
2、Tensor使用方法
1. 创建Tensor
import torch
import numpy as np
#定义Tensor
tensor = torch.Tensor([[1,2,3],[4,5,6]])
#定义array
arr = np.array([[1,2,3],[4,5,6]])
以上为定义Tensor和numpy中的array的方法,可以看到两者非常类似。
2. Tensor的属性
Tensor有三个主要属性,分别是 torch.dtype
, torch.device
, torch.layout
。torch.dtype
表示当前矩阵的数据类型,torch.device
表示当前矩阵所在的设备,可以是CPU和GPU。torch.layout
表示当前矩阵的内存布局,可为torch.stried
或torch.sparse_coo
。
3. 数据类型
数据类型 | torch.dtype | np.dtype |
---|---|---|
8 位无符号整型 | torch.uint8 | np.uint8 |
8 位有符号整型 | torch.int8 | np.int8 |
16 位有符号整型 | torch.int16 / torch.short | np.int16 |
32 位有符号整型 | torch.int32 / torch.int | np.int32 |
64 位有符号整型 | torch.int64 / torch.long | np.int64 |
16 位浮点型 | torch.float16 / torch.half | np.float16 |
32 位浮点型 | torch.float32 / torch.float | np.float32 |
64 位浮点型 | torch.float64 / torch.double | np.float64 |
以上为Tensor的主要数据类型以及numpy中的对应类型,其中PyTorch默认的数据类型为torch.float。数据类型转换方法如下:
#将tensor转成uint8的两种方法,其他类型依此类推
tensor2 = tensor.uint8()
tensor2 = tensor.type(torch.uint8())
#numpy中的转换方法
arr2 = arr.astype(np.uint8)
4. 常用函数
#创建全0矩阵
torch.zeros(256,256)
np.zeros((256,256))
#创建全1矩阵
torch.ones(256,256)
np.ones((256,256))
#矩阵拼接
torch.cat((a,b), dim=0)
np.concatenate((a,b), axis=0)
#矩阵降维
torch.squeeze(a, dim=0)
np.squeeze(a, axis=0)
#矩阵升维
torch.unsqueeze(a, dim=0)
np.expand_dims(a, axis=0)
以上为PyTorch和numpy中常用函数的示例,我们可以再次看到,Tensor与array的操作十分类似,熟悉numpy的开发者可以很快掌握Tensor的用法。此后,在其他文章中会详细介绍这些函数的用法。