torch.cat的参数dim到底是在哪个维度拼接

文章详细介绍了在PyTorch中使用torch.cat函数进行张量拼接时,如何根据dim参数改变拼接维度,并通过实例展示了不同dim值下的结果。当dim=0时,张量在批次维度拼接;dim=1时,拼接通道维度;而dim=2或3时,要求张量对应维度相同,否则无法拼接。
摘要由CSDN通过智能技术生成

如果将两个维度为 n * c * h * w 进行拼接,如果指定dim为以下值

  • dim = 0, 拼接后维度为 2n * c * h * w
  • dim = 1, 拼接后维度为 n * 2c * h * w
  • dim = 2,拼接后维度为 n * c * 2h * w
  • dim = 3,拼接后维度为 n * c * h * 2w

即 dim = i 就表示在第 i 维度度进行拼接,此时除第 i 维度数可以不同外, 其他维度必须相同, 否则无法拼接。

测试1:

x1 = torch.rand((1, 16, 32, 32))
y1 = torch.rand((1, 32, 32, 32))
  1. 在 dim = 0 拼接
out0 = torch.cat((x1, y1), dim = 0)

报错:RuntimeError: Sizes of tensors must match except in dimension 0. Got 16 and 32 in dimension 1 (The offending index is 1)
即 x1, x2 在其他维度不相等(x1(16, 32, 32),x2(32, 32, 32))
2. 在 dim = 1 拼接

out1 = torch.cat((x1, y1), dim = 1)
print(out1.size())

输出: torch.Size([1, 48, 32, 32]),即在dim = 1 上拼接后为 16 + 32 = 48
同理,在dim = 2 或者 dim = 3 维度拼接都会出错

测试1:测试两个维度一模一样的张量

in_put1 = torch.rand((1, 64, 8, 8))
in_put2 = torch.rand((1, 64, 8, 8))
out0 = torch.cat((in_put1, in_put2), dim = 0)
out1 = torch.cat((in_put1, in_put2), dim = 1)
out2 = torch.cat((in_put1, in_put2), dim = 2)
out3 = torch.cat((in_put1, in_put2), dim = 3)

我们的预期结果当然是
out0: (2, 64, 8, 8)
out1: (1, 128, 8, 8)
out2: (1, 64, 16, 6)
out2: (1, 64, 8, 16)
查看打印结果的确如此:
在这里插入图片描述

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值