pytorch常用函数 API学习笔记教程 快速查表

前言

本文内容基于pytorch 1.6版本进行学习,记录下pytorch在实际场景中常用的API以备编程时查阅。

本文大量参考借鉴了B站视频《PyTorch学这个就够了》

pytorch开发环境配置请参考pytorch安装 CUDA安装

1. 创建张量(tensor)

API解释共享内存
a.shape a.size()返回张量a的形状,返回值为张量类型
torch.from_numpy(numpy数组)返回numpy数组转化而来的张量,内存共享
torch.tensor([1,2,3,4,5])定义张量,数值指定,深拷贝×
torch.FloatTensor([1,2,3])返回Float型张量,深拷贝×
torch.FloatTensor(4,3,6)返回4×3×6大小的Float型张量,数值随机
torch.Tensor(4,3,6)同上,默认FloatTensor,可使用下一条语句修改。
torch.set_default_tensor_type(d=)修改torch环境默认Tensor类型,如torch.DoubleTensor只能填实数类型
torch.ones(3,4)返回3×4的全1张量
torch.zeros(3,4)返回3×4的全0张量
torch.eye(3,3)返回3×3的单位矩阵,只能是二维矩阵
torch.full([3,4],7,dtype=int)返回3×4的全为7的张量,且类型为int
torch.arange(0,5)torch.arange(0,10,2)前者返回tensor([0,1,2,3,4]),后者返回tensor([0,2,4,6,8])
torch.linspace(0,10,steps=4)返回一维tensor([0.00, 3.33, 6.67, 10.00]),长度为steps,间距相等
torch.logspace(0,1,steps=4)返回一维tensor([1.00, 2.15, 4.64, 10.00]), f i = 1 0 x f_i=10^x fi=10x x x x等距分布在[0,1]
torch.randperm(10)返回0到9这10个整数的一维乱序张量
torch.rand(3,4)返回3×4张量,数值服从[0,1]均匀分布
torch.randn(3,4)返回3×4张量,数值服从标准正态分布
torch.rand_like(a)返回与a同shape的张量,数值服从[0,1]均匀分布
torch.randint(1,10, [3,4])返回3×4的张量,数值从[1,10)随机选取
torch.normal(mean=,std=)返回服从正态分布的张量。mean和std均为张量(形状一致),输出张量每个元素服从每一对(mean,std)正态分布

2. 比较大小 & bool张量快速判断

以下5个函数,如果x.shape=(3,4,5),则y必须是【一个数字y.shape=(5)y.shape=(4, 5)y.shape=(3, 4, 5)】其中之一。返回值为与x同形状的bool型张量

API解释
x.ge(y)等价于x>=y
x.le(y)等价于x<=y
x.gt(y)等价于x>y
x.lt(y)等价于x<y
x.eq(y)等价于x==y

以下函数返回值为TrueFalse

API解释
x.equal(y)张量 x x x y y y必须同形状。 x x x y y y对应元素相等时返回True
torch.all(mask)若bool张量mask全为True,返回True;否则返回False
torch.any(mask)若bool张量mask存在True,返回True;否则返回False

3. 索引与切片

假设张量a=torch.rand(4,3,28,28),可以表示4张图片,每个图片3个channel,每个channel大小28×28

API解释内存共享
a[0,0,2,4]每个维度指定索引,返回确切的元素值
a[:2]仅在第0个维度上,截取第0到2(不含)张照片,形状为(2,3,28,28)内存共享
a[0:2, :, :, -1:]第0个维度同上,第1,2维度上全取,第3维度上取最后一个。中间连续:可用...代替
a[:, :, 0:10:3]第0,1维度上全选。第2维度上从0到10步长为3选取,效果同range(0,10,3)
a.index_select(0, torch.tensor([0,2]))第0维度上,选择第0,2张图片。内存不共享,返回复制体,下同×
a.masked_select(mask) a[mask]bool张量maska同形状,抽取a中对应mask位置为True元素,返回一维张量×
a.take(indices)a看作一维,抽出一维张量indices指定位置的元素,返回一维张量

4. 维度变换

假设张量a的形状为(1,4,3,1)b的形状为(2,3)

API解释内存共享
a.view(3,4)返回张量a形状为3×4的形式,参数填-1则自动计算
a.squeeze(0)若第0维度是1,则删掉这个维度;否则不变。实际数据不变
a.unsqueeze(0)squeeze相反,增加第0维度,大小是1
a.expand(4,-1,-1,-1)第0维度“复制”为4份,其他维度不变。“复制”只能是大小为1的维度
a.repeat(4,-1,-1,-1)功能同expand。但repeat深拷贝×

转置操作后,原内存中的数据会发生变化,不再连续,可执行contiguous()获得连续形式

API解释内存共享
b.t()返回二维张量b的转置
a.transpose(0,2)在第0和2维度上进行转置;只能在2个维度上转置
a.permute(2,1,0,3)多维度转置。如第2维度转到第0维度;必须填满所有维度

5. 拼接与拆分

假设张量a的形状为(4,3,2)b的形状为(4,3,2)

API解释内存共享
torch.cat([a,b],dim=0)在第0维度上拼接ab,返回形状(8,3,2)。需保证除第dim维外,形状相同×
torch.stack([a,b],dim=1)在第1维度上组合ab,返回形状(4,2,3,2)。需保证ab形状完全相同×
a.split(split_sizes=1,dim=0)在第dim维度上拆分,每份大小为split_sizes。与cat()互为反操作,但与a共享内存
a.split([1,2,1],dim=0)在第dim维度上拆分,每份大小对应split_sizes每个元素
a.chunk(chunks=4,dim=3)在第dim维度上拆分,平均分为chunks

6. 基本运算

下表运算中,有两种变形:1.torch.*(),如torch.add(a, b)等价于a.add(b);2.a.*_(),如a.add_(b)表示a本身的值会被同时修改为a+b

API解释内存共享
a.add(b) a+b对应元素相加。a.add_(b)则a本身会被修改为a+b的值,下面的所有函数都是。
a.sub(b) a-b对应元素相减
a.mul(b) a*b对应元素相乘
a.div(b) a/b其中b不能是整数类型
x.dot(y)点乘(内积),xy必须都是一维(即向量)
x.mm(y)矩阵乘法,xy必须均为二维(即矩阵)
x.matmul(y) x@y矩阵乘法,兼容dot()mm()
a.pow(2)等价于a**2 求张量每个元素的n次方
a.sqrt()求张量每个元素的平方根。square root
a.rsqrt()sqrt()基础上对每个元素取倒数
a.exp()对每个元素 x x x,求 e x e^x ex
a.log()对每个元素 x x x,求 l n x ln x lnx,即以 e e e为底的对数
a.floor()对每个元素向下取整
a.ceil()对每个元素向上取整
a.trunc()对每个元素取整数部分
a.frac()对每个元素取小数部分
a.round()对每个元素四舍五入,返回类型为torch.FloatTensor
a.clamp(1,3)小于1的元素改为1,大于3的元素改为3。第二参数不填则不限制。内存深拷贝×

7. 数据统计

下表函数均含有默认参数dim=None, keepdim=False,返回值分2种情况:

  1. dim=None,表示对所有元素进行操作,返回一个标量
  2. dim=i,在第i维度上操作,同时可指定keepdim=True是否保持原张量的维数,返回一个元组(values,indices)values是值张量,indices是值所在的下标组成的张量
API解释
a.max()最大值
a.min()最小值
a.median()中位数,偶数个元素时取左中值。max()返回值为max对象,取其values成员才是tensor!min(),median()同理
a.sum()求和
a.mean()均值
a.prod()累乘
a.argmax()最大值的索引,多个最大值时返回最大索引
a.argmin()最小值的索引,多个最小值时返回最大索引
a.norm(p=,dim=None,keepdim=False)求张量第dim维的第p范数,即 ∑ i x i p p \sqrt[p]{\sum_i x_i^p} pixip

以下函数默认dim=None但实际上运算时默认dim=-1,即最低维上。故一定返回一个元组(values,indices)

API解释
b.topk(k=, dim=None, largest=True, sorted=True)在第dim维度上求前k大(largest=False则前k小)
b.kthvalue(k=,dim=None,keepdim=False)在第dim维度上求第k小

8. 复杂操作

API解释
torch.where(condition,x,y)三个参数为同形状张量, o u t ( i ) = { x i , i f ( c o n d i t i o n i    i s    T r u e ) y i , o t h e r w i s e out(i)=\begin{cases}x_i, & if (condition_i \; is \; True) \\ y_i, & otherwise\end{cases} out(i)={xi,yi,if(conditioniisTrue)otherwise
torch.gather(a,dim=,index=)张量aindex的形状必须在同一维度空间中。从第dim维度看index,抽取a中第dim维度对应的那个位置的元素。返回张量与index形状相同
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

雪的期许

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值