torch.cat和torch.sum的理解

本文介绍了如何使用PyTorch的torch.cat()函数进行张量拼接,以及torch.sum()函数对张量按指定维度求和。通过实例展示了不同dim参数下的操作效果,包括按行和按列拼接,以及高维拼接的应用。
摘要由CSDN通过智能技术生成

torch.cat((A,B),dim)torch.cat([A,B],dim)

拼接张量A和B,dim表示拼接方向,dim=0表示按行拼接(即两张量的列大小一致),反之dim=1表示按列拼接。如果是高维的话,就按照dim的数值方向拼接

import torch
A=torch.ones(2,3) #2x3的张量(矩阵)
B=2*torch.ones(4,3)#4x3的张量(矩阵)
C=torch.cat((A,B),0)#按维数0(行)拼接

D=2*torch.ones(2,4) #2x4的张量(矩阵)
E=torch.cat((A,D),1)#按维数1(列)拼接

# 高维拼接
>>> A = torch.zeros((1,2,3,4))
>>> B = torch.zeros((1,2,3,4))
>>> C = torch.cat((A,B),dim=2)
>>> C.shape
torch.Size([1, 2, 6, 4])

在这里插入图片描述

torch.sum(A,dim)
对张量A的某一维度求和

import torch
A=torch.ones(2,3) #2x3的张量(矩阵)
B=torch.sum(A,0)#按维数0(行)相加
C=torch.sum(A,1)#按维数0(列)相加

在这里插入图片描述

>>> import torch
>>> a = torch.ones([1,2,3,4])*2
>>> b = torch.ones([1,2,3,4])
>>> a1 = a.unsqueeze(dim=1)
>>> b1 = b.unsqueeze(dim=1)
>>> a1.shape
torch.Size([1, 1, 2, 3, 4])
>>> b1.shape
torch.Size([1, 1, 2, 3, 4])
>>> c = torch.cat([a1, b1], dim=1)
>>> c.shape
torch.Size([1, 2, 2, 3, 4])
>>> d1 = torch.sum(c,dim=0)
>>> d2 = torch.sum(c,dim=1)
>>> d3 = torch.sum(c,dim=2)
>>> d4 = torch.sum(c,dim=3)

>>> c
tensor([[[[[2., 2., 2., 2.],
           [2., 2., 2., 2.],
           [2., 2., 2., 2.]],
          [[2., 2., 2., 2.],
           [2., 2., 2., 2.],
           [2., 2., 2., 2.]]],
         [[[1., 1., 1., 1.],
           [1., 1., 1., 1.],
           [1., 1., 1., 1.]],
          [[1., 1., 1., 1.],
           [1., 1., 1., 1.],
           [1., 1., 1., 1.]]]]])
>>> d1
tensor([[[[2., 2., 2., 2.],
          [2., 2., 2., 2.],
          [2., 2., 2., 2.]],
         [[2., 2., 2., 2.],
          [2., 2., 2., 2.],
          [2., 2., 2., 2.]]],
        [[[1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.]],
         [[1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.]]]])

>>> d2
tensor([[[[3., 3., 3., 3.],
          [3., 3., 3., 3.],
          [3., 3., 3., 3.]],
         [[3., 3., 3., 3.],
          [3., 3., 3., 3.],
          [3., 3., 3., 3.]]]])
>>> d3
tensor([[[[4., 4., 4., 4.],
          [4., 4., 4., 4.],
          [4., 4., 4., 4.]],
         [[2., 2., 2., 2.],
          [2., 2., 2., 2.],
          [2., 2., 2., 2.]]]])
>>> d4
tensor([[[[6., 6., 6., 6.],
          [6., 6., 6., 6.]],
         [[3., 3., 3., 3.],
          [3., 3., 3., 3.]]]])
>>> c.shape
torch.Size([1, 2, 2, 3, 4])
>>> d1.shape
torch.Size([2, 2, 3, 4])
>>> d2.shape
torch.Size([1, 2, 3, 4])
>>> d3.shape
torch.Size([1, 2, 3, 4])
>>> d4.shape
torch.Size([1, 2, 2, 4])

dim=0,去掉一个括号,里面的内容相加。dim=1,去掉两个括号,里面的内容相加。
从上面实验可以看出,除了dim=0会直接去掉一个括号,其余的dim都会相加减。实际上dim=0也会相加,只是不便于理解

  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值