Pytorch中torch.cat()

  1. torch.cat(tensors, dim=0)是将两个张量(tensor)拼接在一起,cat是concatenate的意思,即拼接。
    tensors:(或者tensors序列)——提供的非空tensor必须具有相同的shape,cat 维度除外。
    dim (int, optional): —— 拼接tensor的维度

  2. 例子:

>>> import torch
>>> A=torch.ones(2,3) #2x3的张量(矩阵)                                     
>>> A
tensor([[ 1.,  1.,  1.],
        [ 1.,  1.,  1.]])
>>> B=2*torch.ones(4,3)#4x3的张量(矩阵)                                    
>>> B
tensor([[ 2.,  2.,  2.],
        [ 2.,  2.,  2.],
        [ 2.,  2.,  2.],
        [ 2.,  2.,  2.]])
>>> C=torch.cat((A,B),0)#按维数0(行)拼接
>>> C
tensor([[ 1.,  1.,  1.],
         [ 1.,  1.,  1.],
         [ 2.,  2.,  2.],
         [ 2.,  2.,  2.],
         [ 2.,  2.,  2.],
         [ 2.,  2.,  2.]])
>>> C.size()
torch.Size([6, 3])
>>> D=2*torch.ones(2,4) #2x4的张量(矩阵)
>>> C=torch.cat((A,D),1)#按维数1(列)拼接
>>> C
tensor([[ 1.,  1.,  1.,  2.,  2.,  2.,  2.],
        [ 1.,  1.,  1.,  2.,  2.,  2.,  2.]])
>>> C.size()
torch.Size([2, 7])

总结:torch.cat((A,B),0),将A和B按改变维度0,不改变维度1的方向拼接 (即按行方向拼接,行数变而列数不变)
torch.cat((A,B),1),将A和B按改变维度1,不改变维度0的方向拼接 (即按列方向拼接,列数变而行数不变)

例子部分感谢my-GRIT大佬,链接:
https://blog.csdn.net/qq_39709535/article/details/80803003?spm=1001.2101.3001.6650.3&utm_medium=distribute.pc_relevant.none-task-blog-2%7Edefault%7ECTRLIST%7Edefault-3-80803003-blog-96428781.pc_relevant_aa&depth_1-utm_source=distribute.pc_relevant.none-task-blog-2%7Edefault%7ECTRLIST%7Edefault-3-80803003-blog-96428781.pc_relevant_aa&utm_relevant_index=5

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值