【Pytorch入门】合并与分割

Pytorch学习笔记——合并与分割

torch.cat

torch.cat(tensors,dim=0,out=None)

作用torch.cat是将两个tensor拼接在一起

torch.cat主要有两种用法,一种是按维数0拼接(按行拼接),另一种是按维数1拼接(按列拼接)

按维数0拼接(按行拼接)

example

我们首先先创建两个二维张量x和y,分别是1行3列和2行3列

import torch
x=torch.zeros(1,3)#1*3的tensor
y=torch.zeros(2,3)#2*3的tensor
print(x)
print(y)

在这里插入图片描述

torch.cat((x,y),dim=0)就表示按维数0拼接x,y

torch.cat((x,y),dim=0)

在这里插入图片描述

两个张量必须用[]或者()包裹,否则会出现如下报错

torch.cat(x,y,dim=0)

在这里插入图片描述

按维数0拼接(按行拼接)

example

依旧是创建两个张量x和y,分别是2行3列和2行1列,利用torch.cat进行拼接,此时dim=1

import torch
x=torch.zeros(2,3)
y=torch.zeros(2,1)
print(x)
print(y)
torch.cat((x,y),dim=1)

在这里插入图片描述

cat的其他用法

除了上述操作之外,cat还可以把list中的tensor拼接起来

x=torch.tensor([[i] for i in range(4)])#tensor类型
y=[x**2 for i in range(4)]#list类型
print(x)
print(y)
t=torch.cat((y),dim=1)
print(t)

在这里插入图片描述

总结

1、在使用torch.cat进行拼接时,除拼接维数dim数值可不同外,其余维数数值需相同
2、两个拼接的tensor需要用[]或者()包裹起来
3、dim=0表示按行拼接,dim=1表示按列拼接
4、cat还可以把list中的tensor拼接起来

ps:二维理解起来比较方便,但是实践中常用到的均是三维起,当维数>2时,dim=0表示以第0维拼接,
对于一个张量的维度,有几个[]就是几维,因此dim=0表示将第一个[]的内容进行拼接,同理dim=n

torch.stack

作用:与torch.cat相同,torch.stack也是用于拼接两个张量,同样有张量列表和维度两个参数

区别于torch.cat:

torch.stack()函数要求输入张量的大小完全相同,得到的张量的维度会比输入的张量的大小多1,并且多出的那个维度就是拼接的维度,那个维度的大小就是输入张量的个数。这样讲可能有点难以理解,没关系,我们来看个例子。

example

为了方便理解,我们结合一个实际背景,假设现在我们要统计两个班级的成绩,已知两个班级的人数均为32人,考试科目为8门,构建两个张量x和y

import torch
x=torch.rand(32,8)
y=torch.rand(32,8)
print(torch.cat((x,y),dim=0).shape)
print(torch.cat((x,y),dim=1).shape)

在这里插入图片描述

如果我们用torch.cat进行拼接的话,返回的将会是一个shape为(64,8)即64人,8门课或者(32,16)即32人,16门课,显然此时运用torch.cat不太合适

z=torch.stack((x,y),dim=0)
print(z.shape)

torch.Size([2, 32, 8])

使用torch.stack我们发现,返回的tensor在第0维前面增加了一个维度,而这个维度的维数就是拼接的张量的个数。我们也可以增加在其他的维度前。

print(torch.stack((x,y),dim=2).shape)

torch.Size([32, 8, 2])

torch.split

torch.split(tensor,split_size_or_sections,dim=0)

作用:把一个tensor分割成若干个小tensor

torch.split主要是按照长度进行拆分,第一种情况长度一致,另一种长度不一致

example

我们将之前合并时构建的张量z用来拆分

第一种情况:长度一致——可以直接传递数字,将张量z拆分成两个大小相同的张量a和b

a,b=z.split(1,dim=0)
a.shape,b.shape

(torch.Size([1, 32, 8]), torch.Size([1, 32, 8]))

第二种情况:长度不一致——将对应的长度组成list传递进去
此时我们构造张量z的shape为(6,32,8)用于拆分,将张量z拆分成a,b,c三个张量,其shape分别为(1,32,8),(2,32,8),(3,32,8)

z=torch.rand(6,32,8)
a,b,c=z.split([1,2,3],dim=0)
a.shape,b.shape,c.shape

(torch.Size([1, 32, 8]), torch.Size([2, 32, 8]), torch.Size([3, 32,
8]))

torch.chunk

torch.chunk(intput,chunks,dim=0)

作用:把一个tensor均匀分割成若干个小tensor

参数说明

  • input:需要分割的tensor
  • chunks:想均匀分割的分数,如果该tensor在你要进行分割的维度上的size不能被chunks整除,则最后一份会略小(也可能为空)
  • dim:表示分割的维度

example

a,b,c=torch.chunk(z,3,dim=0)
a.shape,b.shape,c.shape

(torch.Size([2, 32, 8]), torch.Size([2, 32, 8]), torch.Size([2, 32,
8]))

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值