torch.cat() 和 torch.stack()

网上很多的示例,都在讨论二维数据(矩阵),单是对于做图像与深度学习的人来说均是三维起步,一般都是4维,下边以4维数据举例

对于pytorch中的堆叠与拼接函数stack与cat,二者还是有一定的不同

torch.cat这是一个拼接函数(姑且这么说)

直接上例子

a0=torch.Tensor([[[[1,1,1,1],[2,2,2,2]]]])
a1=torch.Tensor([[[[3,3,3,3],[4,4,4,4]]]])
torch.Size([1, 1, 2, 4])

torch.cat((a0,a1),dim=0).type(torch.FloatTensor)

tensor([[[[1., 1., 1., 1.],
          [2., 2., 2., 2.]]],


        [[[3., 3., 3., 3.],
          [4., 4., 4., 4.]]]])
torch.Size([2, 1, 2, 4])

上边dim=0,为以第一维为基准拼接,对于一个张量的维度,有几个放括号就是几维,上边例子a0与a1均为4维张量,因此以第0维拼接就是将第一个中括号内的内容进行拼接。最终的尺度大小为(2,1,2,4)

a0=torch.Tensor([[[[1,1,1,1],[2,2,2,2]]]])
a1=torch.Tensor([[[[3,3,3,3],[4,4,4,4]]]])
torch.cat((a0,a1),dim=1).type(torch.FloatTensor)

tensor([[[[1., 1., 1., 1.],
          [2., 2., 2., 2.]],

         [[3., 3., 3., 3.],
          [4., 4., 4., 4.]]]])
torch.Size([1, 2, 2, 4])

以第1维进行拼接,将第二个括号内的内容进行拼接

a0=torch.Tensor([[[[1,1,1,1],[2,2,2,2]]]])
a1=torch.Tensor([[[[3,3,3,3],[4,4,4,4]]]])
torch.cat((a0,a1),dim=2).type(torch.FloatTensor)
tensor([[[[1., 1., 1., 1.],
          [2., 2., 2., 2.],
          [3., 3., 3., 3.],
          [4., 4., 4., 4.]]]])
torch.Size([1, 1, 4, 4])

 

a0=torch.Tensor([[[[1,1,1,1],[2,2,2,2]]]])
a1=torch.Tensor([[[[3,3,3,3],[4,4,4,4]]]])
torch.cat((a0,a1),dim=3).type(torch.FloatTensor)
tensor([[[[1., 1., 1., 1., 3., 3., 3., 3.],
          [2., 2., 2., 2., 4., 4., 4., 4.]]]])
torch.Size([1, 1, 2, 8])

以上为torch.cat的使用方法,两个tensor的维度必须一致,最终生成的张量的维度也没有变化,但是torch.cat就不一样了

torch.stack

依然使用上边的例子,看看这个函数的功能

a0=torch.Tensor([[[[1,1,1,1],[2,2,2,2]]]])
a1=torch.Tensor([[[[3,3,3,3],[4,4,4,4]]]])
torch.stack((a0,a1),dim=0).type(torch.FloatTensor)

tensor([[[[[1., 1., 1., 1.],
           [2., 2., 2., 2.]]]],



        [[[[3., 3., 3., 3.],
           [4., 4., 4., 4.]]]]])
torch.Size([2, 1, 1, 2, 4])

 

a0=torch.Tensor([[[[1,1,1,1],[2,2,2,2]]]])
a1=torch.Tensor([[[[3,3,3,3],[4,4,4,4]]]])
torch.stack((a0,a1),dim=1).type(torch.FloatTensor)
tensor([[[[[1., 1., 1., 1.],
           [2., 2., 2., 2.]]],


         [[[3., 3., 3., 3.],
           [4., 4., 4., 4.]]]]])
torch.Size([1, 2, 1, 2, 4])

从上边两个例子,可以看出,对于torch.stack来说,会先将原始数据维度扩展一维,然后再按照维度进行拼接,具体拼接操作同torch.cat类似

 

贴个torch.stack()官方文档的截图
在这里插入图片描述

dim代表沿着哪个维度进行堆叠

举个例子:

dim=0时:(dim不写时,默认为0)

在这里插入图片描述

a: 2x3 ; b: 2x3 ; c: 2x2x3

在这里插入图片描述

dim=1时:略

dim=2时:

在这里插入图片描述

a: 2x3 ; b: 2x3 ; c: 2x3x2

在这里插入图片描述

总结

在这里插入图片描述

参考链接

https://zhuanlan.zhihu.com/p/70035580

https://www.pianshen.com/article/10611294719/

 

  • 3
    点赞
  • 13
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值