torch.cat:想要去对那哪一个维度进行concat就必须要保证其他维度的大小是一样的。
如a的shape为(2,2,3),b的shape为(3,2,3),那么在维度0上是可以进行concat的,concat过后得到(5,2,3)而在维度1和2维度2上是不可以进行concat的。
代码如下:
import torch
a=torch.tensor([[[1,2,3],[4,5,6]],
[[7,8,9],[10,11,12]]])
print("a的shape为:",a.shape)
print("tensor_a:",a)
b=torch.tensor([[[13,14,15],[16,17,18]],
[[19,20,21],[22,23,24]],
[[25,26,27],[28,29,30]]])
print("b的shape为:",b.shape)
print("tensor_b:",b)
c=torch.cat((a,b),0)
print("c的shape为:",c.shape)
print("tensor_c:",c)
print(c)
输出为:
a的shape为: torch.Size([2, 2, 3])
tensor_a: tensor([[[ 1, 2, 3],
[ 4, 5, 6]],
[[ 7, 8, 9],
[10, 11, 12]]])
b的shape为: torch.Size([3, 2, 3])
tensor_b: tensor([[[13, 14, 15],
[16, 17, 18]],
[[19, 20, 21],
[22, 23, 24]],
[[25, 26, 27],
[28, 29, 30]]])
c的shape为: torch.Size([5, 2, 3])
tensor_c: tensor([[[ 1, 2, 3],
[ 4, 5, 6]],
[[ 7, 8, 9],
[10, 11, 12]],
[[13, 14, 15],
[16, 17, 18]],
[[19, 20, 21],
[22, 23, 24]],
[[25, 26, 27],
[28, 29, 30]]])
tensor([[[ 1, 2, 3],
[ 4, 5, 6]],
[[ 7, 8, 9],
[10, 11, 12]],
[[13, 14, 15],
[16, 17, 18]],
[[19, 20, 21],
[22, 23, 24]],
[[25, 26, 27],
[28, 29, 30]]])
在图像领域,我们经常对特征图进行一个channel的concat,如特征图1大小为(16,16,256),特征图2大小为(16,16,512),将特征图1和特征图2进行一个channel上的concat,concat过后的特征图大小为(16,16,768)
,
下面以lasot数据集中两张图片大小为(3,720,1280)作为concat演示
00000001.jpg
00000002.jpg
import torch
import torchvision.transforms as transforms
import cv2
img_1=cv2.imread('00000001.jpg')
img_2=cv2.imread('00000002.jpg')
print("opencv_img_1的shape:",img_1.shape)
transfer=transforms.ToTensor()
img_1_tensor=transfer(img_1)
print("img_1的shape:",img_1_tensor.shape)
img_2_tensor=transfer(img_2)
print("img_2的shape:",img_2_tensor.shape)
img_1_2_cat_tensor=torch.cat((img_1_tensor,img_2_tensor),0)
print("img_3的shape:",img_1_2_cat_tensor.shape)
输出为:
opencv_img_1的shape :(720, 1280, 3)
img_1的shape: torch.Size([3, 720, 1280])
img_2的shape: torch.Size([3, 720, 1280])
img_3的shape: torch.Size([6, 720, 1280])
注:1.使用torch.cat注意维度的匹配,对某一个维度进行concat必须保证其他维度相同。
2.opencv中图片的shape和tensor中的shape不一样,tensor是把图像的channel放在最前面。