Pytorch-torch.stack 与 torch.cat的拼接过程

本文详细解释了PyTorch中的torch.unsqueeze函数,它是如何在张量的特定维度上添加长度为1的新轴。接着介绍了torch.cat,用于沿指定维度拼接多个张量。最后,探讨了torch.stack,它在新的维度上堆叠张量,需要所有输入张量具有相同的形状。文章强调理解unsqueeze和cat对于掌握stack的重要性,并批评了中文资源中对此的解释往往不够清晰。
摘要由CSDN通过智能技术生成

1.torch.unqueeze

torch.unsqueeze(inputdim=None) → Tensor

 torch.unsqueeze :给input在原dim出添加一个长为1的轴newdim,示例如下:

由此可见,torch.unsequeeze作用只是添加一个长度为1的新轴

  • dim = 0,x.shape: (2,3) -> (1,2,3)
  • dim = 1,x.shape: (2,3) -> (2,1,3)
  • dim = 2,x.shape: (2,3) -> (2,3,1)
  • 概括为文字描述:在给dim=i之后的轴的数据套上一个方括号[],如dim=1,将x[0,:],x[1,:]分别用括号括起来,即使x在原0和1轴之间添加一个新轴。
  • 另外逆向操作是torch.squeeze,可以去掉指定的一个长度为1的轴,即去掉一个括号

2.torch.cat

torch.cat(tensorsdim=0*out=None) → Tensor

tensors可以使列表或者元组, 同时这些tensor除了指定的dim外,其余各轴长度必须对应相同

 torch.cat:用来拼接多个tensor,拼接时会按照指定的dim来进行,示例如下:

 由此可见,

  • dim的取值范围是[-x.shape,x.shape-1],负数就是倒数的意思(python特色下标)
  • dim = 0,res.shape = (4,3),拼接dim=0轴上的元素,dim=0上的元素都是一个长度为3的向量,将次向量看成整体即为dim=0轴上的元素,最后x可视为[A,B]向量,x+1即为[A+1,B+1],将这两个拼接就得到了[A,B,A+1,B+1],展开就是最终shape=(4,3)的结果
  • dim = 1,res.shape = (2,6), 拼接dim=1轴上的元素,dim=1上的元素都是一个个标量,将x,x+1对应dim=1轴上元素进行拼接,如x的0轴上的第一个元素是[0,1,2], 在dim=1轴上展开分别是0,1,2三个标量,将这三个标量与x+1对应0轴第一个元素在1轴展开的三个元素1,2,3进行拼接,继续如此拼接x,x+1的0轴的第二个元素,即得到了最终结果
  • 总的来说torch.cat与各种语言中concat函数并无区别,只是维度可能更高了,被拼接的元素不再限于单个数字或字符这种标量,可能是一个多维张量,便于理解可以将该维度上的非标量元素看成整体,当成标量来拼接。
  • 有两个逆向操作 torch.split() and torch.chunk()

3.torch.stack

 torch.stack(tensorsdim=0*out=None) → Tensor

 tensors:与torch.cat一样,可以使张量列表或元组

torch.stack:与torch.cat类似也是拼接,但不同的是:

  • tensors中的张量必须shape都相同
  • dim取值范围[-x.shape-1,x.shape]
  • torch.stack会在原shape中的dim位置新增一个轴,并按dim轴来cat所有张量 。便于理解,可以认为torch.stack会先将所有张量做torch.unsqueeze(tensors,dim)的操作来新增轴,再将得到的张量列表做torch.cat(tensors_unsqueezed,dim)
  • 示例如下

 这里演示一个难理解点的dim=2的情况:

  1. 先通过torch.unsqueeze(x,dim=2), torch.unsqueeze(x+1,dim=2)来增加一个dim=2轴来得到两个新张量,记为y和z

 2.通过torch.cat((y,z),dim=2)来拼接,最终得到torch.stack((x,x+1),dim=2)的结果

        总的来说,torch.stack就是在一个新轴上cat多个张量,建议先理解unsqueeze和cat,这样才能彻底明白stack。

网上一搜torch.stack和torch.cat中文内容全是抄来抄去,明明都没写清楚 ,只简单说明shape的变化(这谁会理解不了),一堆人照抄假装理解了,可见中文环境是真滴差。这篇文章我参考了pytorch官方文档,也看了油管上的视频讲解,最终结合实践,终于搞懂了stack和cat的过程。

文章编写仓促,可能存在细微疏漏望多多包涵。

pytorch官方文档

torch.stack — PyTorch 2.0 documentation

油管视频 

Stack vs Concat in PyTorch, TensorFlow & NumPy - Deep Learning Tensor Ops

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值