PyTorch关于张量的拆分与合并方法: stack(), cat(), split(), chunk()

张量的拆分与合并常用于深度学习中,接下来主要介绍四个关于张量的拆分与合并方法:stack,cat,split,chunk


stack

  1. 功能:在一个新的维度上堆叠Tensor矩阵。
  2. 用法:torch.stack( tensors, dim=0, out=None) 返回值为tensor

注意:所有将被堆叠的Tensor必须保持同样尺度,且堆叠后会产生新的维度

a,b,c=torch.rand(2,3),torch.rand(2,3),torch.Tensor([])
# 测试out参数
torch.stack([a,b],dim=0,out=c)
print(c)
c=torch.stack([a,b],dim=0)
print(c)
# 测试dim参数
c=torch.stack([a,b],dim=1)
print(c)
# 测试tensors参数
c = torch.stack([a,b],dim=0) # 此时tensors是一个Tensor张量序列
d = torch.stack([a,torch.rand(3,4)],dim=0) # 此时报错,原因是,两个张量尺度不同。

初始化:在这里插入图片描述
dim=0,dim=1
在这里插入图片描述
尺度不一致会报错
在这里插入图片描述

cat

  1. 功能:在给定的维度上堆叠Tensor序列。
  2. 用法:torch.cat( tensors, dim=0, out=None) 返回值为tensor

注意:所有将被堆叠的Tensor要么同样尺度,要么为空,且堆叠后特征维度不变

a,b,c = torch.rand(2,3),torch.rand(2,3),torch.Tensor()
torch.cat([a,b],dim=0)
torch.cat((a,b),dim=1

在这里插入图片描述
不同于stack,cat不会产生新的维度;

split

  1. 功能:拆分Tensor为若干块。
  2. 用法:torch.split( tensors, split_size_or_section,dim=0) 返回值为tensor
a = torch.rand(4,6)
# 两种拆分方式:
# 第一种
b,c = torch.split(a,split_size_or_section=2,dim=0)
# 第二种
b,c = torch.split(a,split_size_or_section=[2,2],dim=0)

在这里插入图片描述

chunk

  1. 功能:拆分Tensor为若干块。
  2. 用法:torch.chunk( tensors, chunks,dim) 返回值为tensor
a = torch.rand(4,6)
b,c = torch.chunk(a,2,dim=0)
b,c = torch.chunk(a,2,dim=1)

在这里插入图片描述


至此,简单介绍完成。有问题可以评论留言,共同学习。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值