torch.cat()函数

一张图像在计算机中的表示通常为三维tensor(张量),即[channels,height,width] 。也就是一张彩色图片通常有三色通道(R,G,B)组成,高和宽也就是常说的照片大小,比如224x224

在图像处理的时候会增加一个变量batch_size,也就是把多少张图片作为一批进行处理。所以就变成了四维张量,即[batch_size,channels,heigth,width],也即是[批量大小,通道数,高,宽]

如何判断一个tensor是几维张量最简单的办法就是看中括号数。例如 [[[[1,2,3]]]],是四维张量。

torch.cat()函数,官方文档是这样写的torch.cat(tensors, dim=0, *, out=None),也就是有两个参数,一个是要合并的张量,一个是在哪个维度上进行合并。

废话少说开始演示。

import torch

a=torch.tensor([[[[1,1,1],[2,2,2]]]])

b=torch.tensor([[[[3,3,3],[4,4,4]]]])

print(a.shape,b.shape)

#torch.Size([1, 1, 2, 3]) torch.Size([1, 1, 2, 3])

定义了两个四维张量。维度都为[1,1,2,3],即批量大小为1,通道为1,高为2,宽为3

import torch

a=torch.tensor([[[[1,1,1],[2,2,2]]]])

b=torch.tensor([[[[3,3,3],[4,4,4]]]])

print(a.shape,b.shape)

#torch.Size([1, 1, 2, 3]) torch.Size([1, 1, 2, 3])

#在维度0上面进行合并

x=torch.cat((a,b),dim=0)

print(x.shape)

#torch.Size([2, 1, 2, 3])

在维度0上进行合并,然后输出维度为[2,1,2,3],所以得出结论 四维张量在0维合并的时候 其实是在批量大小维度上进行合并。

import torch

a=torch.tensor([[[[1,1,1],[2,2,2]]]])

b=torch.tensor([[[[3,3,3],[4,4,4]]]])

print(a.shape,b.shape)

#torch.Size([1, 1, 2, 3]) torch.Size([1, 1, 2, 3])

#在维度0上面进行合并

x=torch.cat((a,b),dim=0)

print(x.shape)

#torch.Size([2, 1, 2, 3])

#在维度1上进行合并

x=torch.cat((a,b),dim=1)

print(x.shape)

#torch.Size([1, 2, 2, 3])

在1维度上进行合并,输出维度为[1,2,2,3],即在1维上合并是在通道维度上进行合并。

import torch

a=torch.tensor([[[[1,1,1],[2,2,2]]]])

b=torch.tensor([[[[3,3,3],[4,4,4]]]])

print(a.shape,b.shape)

#torch.Size([1, 1, 2, 3]) torch.Size([1, 1, 2, 3])

#在维度0上面进行合并

x=torch.cat((a,b),dim=0)

print(x.shape)

#torch.Size([2, 1, 2, 3])

#在维度1上进行合并

x=torch.cat((a,b),dim=1)

print(x.shape)

#torch.Size([1, 2, 2, 3])

#在维度2上进行合并

x=torch.cat((a,b),dim=2)

print(x.shape)

#torch.Size([1, 1, 4, 3])

在维度2上进行合并,输出维度为[1,1,4,3]。即在2维上进行合并是在高上进行合并(也可以说是在行维度进行合并)

import torch

a=torch.tensor([[[[1,1,1],[2,2,2]]]])

b=torch.tensor([[[[3,3,3],[4,4,4]]]])

print(a.shape,b.shape)

#torch.Size([1, 1, 2, 3]) torch.Size([1, 1, 2, 3])

#在维度0上面进行合并

x=torch.cat((a,b),dim=0)

print(x.shape)

#torch.Size([2, 1, 2, 3])

#在维度1上进行合并

x=torch.cat((a,b),dim=1)

print(x.shape)

#torch.Size([1, 2, 2, 3])

#在维度2上进行合并

x=torch.cat((a,b),dim=2)

print(x.shape)

#torch.Size([1, 1, 4, 3])

#在维度3上进行合并

x=torch.cat((a,b),dim=3)

print(x.shape)

#torch.Size([1, 1, 2, 6])

在维度3上进行合并,输出维度为[1,1,2,6],即在3维上进行合并是在宽维度进行合并(也可以说是列)

注:在拼接时 除了选择拼接的维度可以不同,其他维度要相同。什么意思?看代码

import torch

#定义两个变量[batch_size,channel,height,width]

a=torch.randn(size=(1,1,2,3))

b=torch.randn(size=(1,2,2,3))

#选择在1维度进行合并(也就是通道维度),注意a,b的通道维度不同,其他维度都相同。

x=torch.cat((a,b),dim=1)

print(x.shape)

#torch.Size([1, 3, 2, 3])

也就是选择合并的那个维度可以不同,其他维度要相同

如果不同,报错,如下。

import torch

#定义两个变量[batch_size,channel,height,width]

a=torch.randn(size=(1,1,2,3))

b=torch.randn(size=(2,2,2,3))

#选择在1维度进行合并(也就是通道维度),注意a,b的批量大小不同,维度不同,其他维度都相同。

x=torch.cat((a,b),dim=1)

print(x.shape)

#RuntimeError: Sizes of tensors must match except in dimension 1. Got 1 and 2 in dimension 0

可以看到当我们选择在通道维度合并时(通道数可以不同),但是其他的维度要相同(下面的a,b的批量大小也不同)。所以直接报错。

原文链接:https://blog.csdn.net/zwb619/article/details/127022873

  • 6
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值