前言
本文内容基于pytorch 1.6版本进行学习,记录下pytorch在实际场景中常用的API以备编程时查阅。
本文大量参考借鉴了B站视频《PyTorch学这个就够了》
pytorch开发环境配置请参考pytorch安装 CUDA安装
1. 创建张量(tensor)
API | 解释 | 共享内存 |
---|---|---|
a.shape a.size() | 返回张量a的形状,返回值为张量类型 | |
torch.from_numpy(numpy数组) | 返回numpy数组转化而来的张量,内存共享 | √ |
torch.tensor([1,2,3,4,5]) | 定义张量,数值指定,深拷贝 | × |
torch.FloatTensor([1,2,3]) | 返回Float型张量,深拷贝 | × |
torch.FloatTensor(4,3,6) | 返回4×3×6大小的Float型张量,数值随机 | |
torch.Tensor(4,3,6) | 同上,默认FloatTensor,可使用下一条语句修改。 | |
torch.set_default_tensor_type(d=) | 修改torch 环境默认Tensor类型,如torch.DoubleTensor ,只能填实数类型 | |
torch.ones(3,4) | 返回3×4的全1张量 | |
torch.zeros(3,4) | 返回3×4的全0张量 | |
torch.eye(3,3) | 返回3×3的单位矩阵,只能是二维矩阵 | |
torch.full([3,4],7,dtype=int) | 返回3×4的全为7的张量,且类型为int | |
torch.arange(0,5) 或torch.arange(0,10,2) | 前者返回tensor([0,1,2,3,4]),后者返回tensor([0,2,4,6,8]) | |
torch.linspace(0,10,steps=4) | 返回一维tensor([0.00, 3.33, 6.67, 10.00]),长度为steps,间距相等 | |
torch.logspace(0,1,steps=4) | 返回一维tensor([1.00, 2.15, 4.64, 10.00]), f i = 1 0 x f_i=10^x fi=10x, x x x等距分布在[0,1] | |
torch.randperm(10) | 返回0到9这10个整数的一维乱序张量 | |
torch.rand(3,4) | 返回3×4张量,数值服从[0,1]均匀分布 | |
torch.randn(3,4) | 返回3×4张量,数值服从标准正态分布 | |
torch.rand_like(a) | 返回与a同shape的张量,数值服从[0,1]均匀分布 | |
torch.randint(1,10, [3,4]) | 返回3×4的张量,数值从[1,10)随机选取 | |
torch.normal(mean=,std=) | 返回服从正态分布的张量。mean和std均为张量(形状一致),输出张量每个元素服从每一对(mean,std)正态分布 |
2. 比较大小 & bool张量快速判断
以下5个函数,如果x.shape=(3,4,5)
,则y
必须是【一个数字
,y.shape=(5)
,y.shape=(4, 5)
,y.shape=(3, 4, 5)
】其中之一。返回值为与x
同形状的bool
型张量。
API | 解释 |
---|---|
x.ge(y) | 等价于x>=y |
x.le(y) | 等价于x<=y |
x.gt(y) | 等价于x>y |
x.lt(y) | 等价于x<y |
x.eq(y) | 等价于x==y |
以下函数返回值为True
或False
API | 解释 |
---|---|
x.equal(y) | 张量
x
x
x与
y
y
y必须同形状。
x
x
x与
y
y
y对应元素相等时返回True |
torch.all(mask) | 若bool张量mask 全为True ,返回True ;否则返回False |
torch.any(mask) | 若bool张量mask 存在True ,返回True ;否则返回False |
3. 索引与切片
假设张量a=torch.rand(4,3,28,28)
,可以表示4张图片,每个图片3个channel,每个channel大小28×28
API | 解释 | 内存共享 |
---|---|---|
a[0,0,2,4] | 每个维度指定索引,返回确切的元素值 | |
a[:2] | 仅在第0个维度上,截取第0到2(不含)张照片,形状为(2,3,28,28) ,内存共享 | √ |
a[0:2, :, :, -1:] | 第0个维度同上,第1,2维度上全取,第3维度上取最后一个。中间连续: 可用... 代替 | √ |
a[:, :, 0:10:3] | 第0,1维度上全选。第2维度上从0到10步长为3选取,效果同range(0,10,3) | √ |
a.index_select(0, torch.tensor([0,2])) | 第0维度上,选择第0,2张图片。内存不共享,返回复制体,下同 | × |
a.masked_select(mask) a[mask] | bool张量mask 与a 同形状,抽取a 中对应mask 位置为True 的元素,返回一维张量 | × |
a.take(indices) | 把a 看作一维,抽出一维张量indices 指定位置的元素,返回一维张量 |
4. 维度变换
假设张量a
的形状为(1,4,3,1)
,b
的形状为(2,3)
API | 解释 | 内存共享 |
---|---|---|
a.view(3,4) | 返回张量a 形状为3×4的形式,参数填-1则自动计算 | √ |
a.squeeze(0) | 若第0维度是1 ,则删掉这个维度;否则不变。实际数据不变 | √ |
a.unsqueeze(0) | 与squeeze 相反,增加第0维度,大小是1 | √ |
a.expand(4,-1,-1,-1) , | 第0维度“复制”为4份,其他维度不变。“复制”只能是大小为1 的维度 | √ |
a.repeat(4,-1,-1,-1) | 功能同expand 。但repeat 是深拷贝 | × |
转置操作后,原内存中的数据会发生变化,不再连续,可执行contiguous()
获得连续形式。
API | 解释 | 内存共享 |
---|---|---|
b.t() | 返回二维张量b 的转置 | √ |
a.transpose(0,2) | 在第0和2维度上进行转置;只能在2个维度上转置 | √ |
a.permute(2,1,0,3) | 多维度转置。如第2维度转到第0维度;必须填满所有维度 | √ |
5. 拼接与拆分
假设张量a
的形状为(4,3,2)
,b
的形状为(4,3,2)
。
API | 解释 | 内存共享 |
---|---|---|
torch.cat([a,b],dim=0) | 在第0维度上拼接a 与b ,返回形状(8,3,2) 。需保证除第dim 维外,形状相同 | × |
torch.stack([a,b],dim=1) | 在第1维度上组合a 与b ,返回形状(4,2,3,2) 。需保证a 与b 形状完全相同 | × |
a.split(split_sizes=1,dim=0) | 在第dim 维度上拆分,每份大小为split_sizes 。与cat() 互为反操作,但与a 共享内存 | √ |
a.split([1,2,1],dim=0) | 在第dim 维度上拆分,每份大小对应split_sizes 每个元素 | √ |
a.chunk(chunks=4,dim=3) | 在第dim 维度上拆分,平均分为chunks 份 | √ |
6. 基本运算
下表运算中,有两种变形:1.torch.*()
,如torch.add(a, b)
等价于a.add(b)
;2.a.*_()
,如a.add_(b)
表示a
本身的值会被同时修改为a+b
。
API | 解释 | 内存共享 |
---|---|---|
a.add(b) a+b | 对应元素相加。a.add_(b) 则a本身会被修改为a+b 的值,下面的所有函数都是。 | |
a.sub(b) a-b | 对应元素相减 | |
a.mul(b) a*b | 对应元素相乘 | |
a.div(b) a/b | 其中b 不能是整数类型 | |
x.dot(y) | 点乘(内积),x 和y 必须都是一维(即向量) | |
x.mm(y) | 矩阵乘法,x 和y 必须均为二维(即矩阵) | |
x.matmul(y) x@y | 矩阵乘法,兼容dot() 和mm() | |
a.pow(2) | 等价于a**2 求张量每个元素的n次方 | |
a.sqrt() | 求张量每个元素的平方根。square root | |
a.rsqrt() | 在sqrt() 基础上对每个元素取倒数 | |
a.exp() | 对每个元素 x x x,求 e x e^x ex | |
a.log() | 对每个元素 x x x,求 l n x ln x lnx,即以 e e e为底的对数 | |
a.floor() | 对每个元素向下取整 | |
a.ceil() | 对每个元素向上取整 | |
a.trunc() | 对每个元素取整数部分 | |
a.frac() | 对每个元素取小数部分 | |
a.round() | 对每个元素四舍五入,返回类型为torch.FloatTensor | |
a.clamp(1,3) | 小于1的元素改为1,大于3的元素改为3。第二参数不填则不限制。内存深拷贝 | × |
7. 数据统计
下表函数均含有默认参数dim=None, keepdim=False
,返回值分2种情况:
dim=None
,表示对所有元素进行操作,返回一个标量。dim=i
,在第i
维度上操作,同时可指定keepdim=True
是否保持原张量的维数,返回一个元组(values,indices)
,values
是值张量,indices
是值所在的下标组成的张量。
API | 解释 |
---|---|
a.max() | 最大值 |
a.min() | 最小值 |
a.median() | 中位数,偶数个元素时取左中值。max()返回值为max对象,取其values成员才是tensor!min(),median()同理 |
a.sum() | 求和 |
a.mean() | 均值 |
a.prod() | 累乘 |
a.argmax() | 最大值的索引,多个最大值时返回最大索引 |
a.argmin() | 最小值的索引,多个最小值时返回最大索引 |
a.norm(p=,dim=None,keepdim=False) | 求张量第dim 维的第p 范数,即
∑
i
x
i
p
p
\sqrt[p]{\sum_i x_i^p}
p∑ixip |
以下函数默认dim=None
但实际上运算时默认dim=-1
,即最低维上。故一定返回一个元组(values,indices)
API | 解释 |
---|---|
b.topk(k=, dim=None, largest=True, sorted=True) | 在第dim 维度上求前k大(largest=False 则前k小) |
b.kthvalue(k=,dim=None,keepdim=False) | 在第dim 维度上求第k小 |
8. 复杂操作
API | 解释 |
---|---|
torch.where(condition,x,y) | 三个参数为同形状张量, o u t ( i ) = { x i , i f ( c o n d i t i o n i i s T r u e ) y i , o t h e r w i s e out(i)=\begin{cases}x_i, & if (condition_i \; is \; True) \\ y_i, & otherwise\end{cases} out(i)={xi,yi,if(conditioniisTrue)otherwise |
torch.gather(a,dim=,index=) | 张量a 和index 的形状必须在同一维度空间中。从第dim 维度看index ,抽取a 中第dim 维度对应的那个位置的元素。返回张量与index 形状相同 |