torch.cat()用法

torch.cat:想要去对那哪一个维度进行concat就必须要保证其他维度的大小是一样的。
如a的shape为(2,2,3),b的shape为(3,2,3),那么在维度0上是可以进行concat的,concat过后得到(5,2,3)而在维度1和2维度2上是不可以进行concat的。
代码如下:

import torch

a=torch.tensor([[[1,2,3],[4,5,6]],
                [[7,8,9],[10,11,12]]])
print("a的shape为:",a.shape)
print("tensor_a:",a)
b=torch.tensor([[[13,14,15],[16,17,18]],
              [[19,20,21],[22,23,24]],
              [[25,26,27],[28,29,30]]])
print("b的shape为:",b.shape)
print("tensor_b:",b)
c=torch.cat((a,b),0)
print("c的shape为:",c.shape)
print("tensor_c:",c)
print(c)

输出为:

a的shape为: torch.Size([2, 2, 3])
tensor_a: tensor([[[ 1,  2,  3],
         [ 4,  5,  6]],

        [[ 7,  8,  9],
         [10, 11, 12]]])
b的shape为: torch.Size([3, 2, 3])
tensor_b: tensor([[[13, 14, 15],
         [16, 17, 18]],

        [[19, 20, 21],
         [22, 23, 24]],

        [[25, 26, 27],
         [28, 29, 30]]])
c的shape为: torch.Size([5, 2, 3])
tensor_c: tensor([[[ 1,  2,  3],
         [ 4,  5,  6]],

        [[ 7,  8,  9],
         [10, 11, 12]],

        [[13, 14, 15],
         [16, 17, 18]],

        [[19, 20, 21],
         [22, 23, 24]],

        [[25, 26, 27],
         [28, 29, 30]]])
tensor([[[ 1,  2,  3],
         [ 4,  5,  6]],

        [[ 7,  8,  9],
         [10, 11, 12]],

        [[13, 14, 15],
         [16, 17, 18]],

        [[19, 20, 21],
         [22, 23, 24]],

        [[25, 26, 27],
         [28, 29, 30]]])

在图像领域,我们经常对特征图进行一个channel的concat,如特征图1大小为(16,16,256),特征图2大小为(16,16,512),将特征图1和特征图2进行一个channel上的concat,concat过后的特征图大小为(16,16,768)
,
下面以lasot数据集中两张图片大小为(3,720,1280)作为concat演示
00000001.jpg
在这里插入图片描述00000002.jpg
在这里插入图片描述

import  torch
import torchvision.transforms as transforms
import cv2
img_1=cv2.imread('00000001.jpg')
img_2=cv2.imread('00000002.jpg')
print("opencv_img_1的shape:",img_1.shape)
transfer=transforms.ToTensor()
img_1_tensor=transfer(img_1)
print("img_1的shape:",img_1_tensor.shape)
img_2_tensor=transfer(img_2)
print("img_2的shape:",img_2_tensor.shape)
img_1_2_cat_tensor=torch.cat((img_1_tensor,img_2_tensor),0)
print("img_3的shape:",img_1_2_cat_tensor.shape)

输出为:

opencv_img_1的shape :(720, 1280, 3)
img_1的shape: torch.Size([3, 720, 1280])
img_2的shape: torch.Size([3, 720, 1280])
img_3的shape: torch.Size([6, 720, 1280])

注:1.使用torch.cat注意维度的匹配,对某一个维度进行concat必须保证其他维度相同。
2.opencv中图片的shape和tensor中的shape不一样,tensor是把图像的channel放在最前面。

参考:
pytorch:把图片数据转化成tensor

  • 1
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值