【Pytorch】torch.cat() Pytorch中对于张量拼接函数的用法和示例以及dim=-1的应用场景

本文详细解释了如何在PyTorch中使用`torch.cat()`函数合并张量,特别关注dim=-1的情况,以及它如何根据输入张量的维度动态扩展。通过实例展示了不同dim值下的合并效果,并揭示了dim=-1相当于沿张量内最内层维度进行合并,无需指定具体形状。
摘要由CSDN通过智能技术生成

目录

背景

张量例子

dim取值范围在[0, len(inputs[0])]时

dim取值为-1时


背景

遇到一个张量合并问题,对于dim=-1的合并情况不清楚,在寻找博客案例时都没有提到dim=-1的情况,所以自己写了几个案例补充一下对于维度为-1的情况时怎样的

张量例子

对于理解torch.cat()和torch.stack()直接用例子理解是最好用的。特别注意dim维度的理解

给出以下两个张量的例子:

x1 = torch.tensor([[11,21,31],[34,51,23]], dtype=torch.int)
x2 = torch.tensor([[34,56,67],[62,91,56]], dtype=torch.int)

dim取值范围在[0, len(inputs[0])]时

其实乍一看这个len(inputs[0])非常难理解,所以不要这么去记。

先来看例子:

1.对x1,x2这两个张量进行合并,注意此时这两个张量的维数都为2并且其形状都为(2,3)

1):在维数等于0时进行合并

x3 = torch.cat((x1,x2), dim=0)
x3: tensor([[11, 21, 31],
        [34, 51, 23],
        [34, 56, 67],
        [62, 91, 56]], dtype=torch.int32)
x3 shape: torch.Size([4, 3])

2):在维数等于1时进行合并

x3 = torch.cat((x1,x2), dim=1)
x3: tensor([[11, 21, 31, 34, 56, 67],
        [34, 51, 23, 62, 91, 56]], dtype=torch.int32)
x3 shape: torch.Size([2, 6])

3):在维数等于2时进行合并

x3 = torch.cat((x1,x2), dim=2)

IndexError: Dimension out of range (expected to be in range of [-2, 1], but got 2)

在dim=2时出现报错

-------------------------------------------------------------------------------------------------------------------------------

看到这里可能还是不明白为什么在维度等于2时出现了报错,这个维度究竟是怎么计算的

所以我们再次扩展一下x1和x2使其成为三维张量

x1 = torch.tensor([[[11,21,31],[34,51,23]],[[76,56,89],[35,74,62]]], dtype=torch.int)
x2 = torch.tensor([[[34,56,67],[62,91,56]],[[67,35,98],[38,79,25]]], dtype=torch.int)

那么现在这个两个张量的形状为(2,2,3)

2. 对于x1和x2再次进行合并

x3 = torch.cat((x1,x2), dim=0)
x3: tensor([[[11, 21, 31],
         [34, 51, 23]],

        [[76, 56, 89],
         [35, 74, 62]],

        [[34, 56, 67],
         [62, 91, 56]],

        [[67, 35, 98],
         [38, 79, 25]]], dtype=torch.int32)
x3 shape: torch.Size([4, 2, 3])

x3 = torch.cat((x1,x2), dim=1)
x3: tensor([[[11, 21, 31],
         [34, 51, 23],
         [34, 56, 67],
         [62, 91, 56]],

        [[76, 56, 89],
         [35, 74, 62],
         [67, 35, 98],
         [38, 79, 25]]], dtype=torch.int32)
x3 shape: torch.Size([2, 4, 3])

x3 = torch.cat((x1,x2), dim=2)
x3: tensor([[[11, 21, 31, 34, 56, 67],
         [34, 51, 23, 62, 91, 56]],

        [[76, 56, 89, 67, 35, 98],
         [35, 74, 62, 38, 79, 25]]], dtype=torch.int32)
x3 shape: torch.Size([2, 2, 6])

x3 = torch.cat((x1,x2), dim=3)
# IndexError: Dimension out of range (expected to be in range of [-3, 2], but got 3)

可以观察到这次在dim为2时并没有出现报错,因为这一次输入的x1和x2的形状大小为(2,2,3)有三个维度的信息,所以dim可以取值0,1,2

总结:dim指代的是以哪一个维度信息进行合并,在Pytorch中维度信息的表示方式是从外向内的。这一过程可以形象的理解为拨香蕉。(可以回看张量例子中的x1,x2的形状)所以当dim=0时指代的是按照最外层的list进行拼接,当dim=1时指代按照第二层list进行拼接,以此类推。所以输入的张量维数决定dim的取值范围。再会看定义的范围就非常好理解了

dim取值为-1时

先说结论:dim=-1与dim=len(inputs[0])一样,即适用最里面的list进行合并。在python中-1的用法类似,均指代最后一个,所以dim=-1的情况就不难理解了

接下来说一下dim=-1的优点:

· 不需要知道输入具体的形状,适用于想要对于dim=len(inputs[0])进行扩展的情况

· 特别适合末端形状维度的扩展,举个例子⬇

仍然使用最开始的x1张量:

x1 = torch.tensor([[11,21,31],[34,51,23]], dtype=torch.int)

另外增加一个零向量

x0 = torch.zeros(x1.shape[0], 1)

它打印出来的结果为

x0:
tensor([[0.],[0.]])
torch.Size([2, 1])

将x0和x1合并:

x3 = torch.cat((x1,x0), dim=-1)

x3: tensor([[11., 21., 31.,  0.],
        [34., 51., 23.,  0.]])
x3 shape: torch.Size([2, 4])

可以看到形状从(2,3)扩充为(2,4)

在很多的数据处理中,这一步十分重要

  • 7
    点赞
  • 16
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

江川轻寒

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值