torch.cat()与torch.stack()——数组的拼接

PyTorch学习笔记:torch.cat与torch.stack——数组的拼接

torch.cat()

torch.cat(tensors, dim=0, *, out=None) → Tensor

官方解释:利用给定的维度连接给定的数组序列(cat代表concatenate),所有数组必须具有相同的形状(连接维度除外)或为空。
相当于按指定维度将数组进行拼接

参数解释:

  • tensors:要连接的数组序列(元组tuple或者列表list)
  • dim:数组连接的维度
  • out:输出数组(一般用不到,如果有输出,则可以直接进行赋值操作)

注意:
tensors输入的必须是数组序列,不能是单个数组;
②输入的数组序列除了dim维度,其他维度必须形状相同

举例:

import torch
a=torch.arange(6).reshape(2,3)
b=torch.arange(12)
c=torch.cat((a,b.reshape(4,3)),dim=0)
# 沿第0维度进行拼接,也就是按行拼接(竖着拼)
d=torch.cat((a,b.reshape(2,6)),dim=1)
# 沿第1维度进行拼接,也就是按列拼接(横着拼)
print(c)
print(c.shape)
print(d)
print(d.shape)

输出:

tensor([[ 0,  1,  2],
        [ 3,  4,  5],
        [ 0,  1,  2],
        [ 3,  4,  5],
        [ 6,  7,  8],
        [ 9, 10, 11]])
torch.Size([6, 3])
tensor([[ 0,  1,  2,  0,  1,  2,  3,  4,  5],
        [ 3,  4,  5,  6,  7,  8,  9, 10, 11]])
torch.Size([2, 9])

利用torch.cat()沿dim拼接,在形状上看相当于对dim进行相加,其余维度大小不变,利用这个思想,可以很容易理解高维数组的拼接

高维举例:

import torch
a=torch.ones(4*256*56*56).reshape(4,256,56,56)
b=torch.arange(4*128*56*56).reshape(4,128,56,56)
c=torch.zeros(4*64*56*56).reshape(4,64,56,56)
d=torch.cat((a,b,c),dim=1)
print(d.shape)

输出:

torch.Size([4, 448, 56, 56])

上述例子在卷积神经网络中常用于特征图的堆叠

torch.stack()

torch.stack(tensors, dim=0, *, out=None) → Tensor

官方解释:沿着新的维度连接一系列数组,所有的数组都需要具有相同的大小。
相当于先将多个n维数组进行扩维操作,然后再拼接为一个n+1维的数组

参数解释:

  • tensors:要连接的数组序列(元组tuple或者列表list)
  • dim:要插入的维度,大小必须介于0和需要拼接的数组维数之间(dim最大不超过数组的维数)
  • out:输出数组(与cat()类似,一般用不到)

注意:
①与cat类似,必须输入数组序列,不能是单个数组;
②输入的所有数组序列形状(尺寸)必须一致(这里与cat有区别)。

举例:

import torch
a=torch.arange(12).reshape(3,4)
b=torch.ones(12).reshape(3,4)
c=torch.stack((a,b),dim=0)
d=torch.stack((a,b),dim=1)
e=torch.stack((a,b),dim=2)
# dim最大可到输入数组的维数,即a、b的维数
print(c)
print(c.shape)
print(d)
print(d.shape)
print(e)
print(e.shape)

输出:

tensor([[[ 0.,  1.,  2.,  3.],
         [ 4.,  5.,  6.,  7.],
         [ 8.,  9., 10., 11.]],

        [[ 1.,  1.,  1.,  1.],
         [ 1.,  1.,  1.,  1.],
         [ 1.,  1.,  1.,  1.]]])
torch.Size([2, 3, 4])
tensor([[[ 0.,  1.,  2.,  3.],
         [ 1.,  1.,  1.,  1.]],

        [[ 4.,  5.,  6.,  7.],
         [ 1.,  1.,  1.,  1.]],

        [[ 8.,  9., 10., 11.],
         [ 1.,  1.,  1.,  1.]]])
torch.Size([3, 2, 4])
tensor([[[ 0.,  1.],
         [ 1.,  1.],
         [ 2.,  1.],
         [ 3.,  1.]],

        [[ 4.,  1.],
         [ 5.,  1.],
         [ 6.,  1.],
         [ 7.,  1.]],

        [[ 8.,  1.],
         [ 9.,  1.],
         [10.,  1.],
         [11.,  1.]]])
torch.Size([3, 4, 2])

仔细观察上个案例维度的变化,可以发现当输入为两组数组时,dim定为几,拼接后哪个维度就是2(有两个输入数组),相当于做了一个扩维拼接操作。首先按dim增加一个维度,然后再从该维度上进行拼接操作。

cat与stack的区别

torch.cat()是直接在原数组数据上进行拼接,不会改变维数大小;torch.stack首先进行扩维,然后再进行拼接,会将维数增大一个单位。

官方文档

torch.cat():https://pytorch.org/docs/stable/generated/torch.cat.html
torch.stack():https://pytorch.org/docs/stable/generated/torch.stack.html

点个赞支持一下吧

### 回答1: torch.cat() 和 torch.stack() 都是 PyTorch 中的 Tensor 操作函数,用于对 Tensor 进行拼接和堆叠。 torch.cat() 用于对 Tensor 进行按维度拼接。例如,如果你有三个形状为 (2, 3) 的 Tensor,你可以使用 torch.cat() 将它们拼接成形状为 (6, 3) 的 Tensor。 torch.stack() 用于对 Tensor 进行按维度堆叠。例如,如果你有三个形状为 (2, 3) 的 Tensor,你可以使用 torch.stack() 将它们堆叠成形状为 (3, 2, 3) 的 Tensor。 ### 回答2: PyTorch是当前机器学习和深度学习领域最流行的框架之一。torch.cat()和torch.stack()是PyTorch中最常用的两个函数之一,它们都可以将多个张量进行拼接torch.cat()函数可以对给定的多个张量进行维度拼接,也就是将输入张量按照给定维度进行连接。这个函数最常用的情况是在将两个张量按照某个维度进行连接,也可以将多个张量按照某个维度进行连接。举个例子,如果有两个张量A和B,且它们的第一个维度相同(比如都是5),那么可以使用torch.cat([A,B],dim=1)将它们按照第二个维度进行拼接torch.stack()函数与torch.cat()函数非常相似,也可以将多个张量按照给定维度进行拼接,但是它还有一些额外的功能。除了可以按照给定的维度拼接张量之外,torch.stack()函数还会在拼接后的结果中添加一个新的维度。这个新增的维度会被插入到给定维度的位置。举个例子,如果有两个张量A和B,每个张量的形状都是5*3,使用torch.stack([A,B])将它们按照新的维度进行拼接,那么拼接后的张量形状就会变成2*5*3,其中新增加的维度是1. 总之,torch.cat()和torch.stack()都是非常实用的函数,可以帮助我们将多个张量按照给定维度进行拼接,并且在更高维度中增加新的张量。需要注意的是,torch.cat()和torch.stack()两个函数的使用方法略有不同,选择使用哪个函数需要视具体情况而定。 ### 回答3: torch.cat()和torch.stack()都是PyTorch中用于拼接多个张量的函数,但它们在使用方式和结果上有所不同。 首torch.cat()的使用方法如下: ```python torch.cat(tensors, dim=0, out=None) -> Tensor ``` 其中,tensors是需要拼接的张量的列表,dim指定拼接的维度,默认为0,out指定输出的Tensor,如果不指定,则会创建一个新的Tensor来存储结果。 例如,我们有两个2×3的张量a和b,想将它们按行拼接起来,代码如下: ```python a = torch.tensor([[1, 2, 3], [4, 5, 6]]) b = torch.tensor([[7, 8, 9], [10, 11, 12]]) c = torch.cat([a, b], dim=0) print(c) ``` 输出结果为: ``` tensor([[ 1, 2, 3], [ 4, 5, 6], [ 7, 8, 9], [10, 11, 12]]) ``` 可以看到,c就是将a和b按行拼接起来得到的结果。 而torch.stack()的使用方法如下: ```python torch.stack(tensors, dim=0, out=None) -> Tensor ``` 其中,tensors是需要拼接的张量的列表,dim指定拼接的维度,默认为0,out指定输出的Tensor,如果不指定,则会创建一个新的Tensor来存储结果。 不同于cat函数,stack函数会创建一个新的维度来存储拼接的结果。例如,我们有两个2×3的张量a和b,想将它们堆叠起来,代码如下: ```python a = torch.tensor([[1, 2, 3], [4, 5, 6]]) b = torch.tensor([[7, 8, 9], [10, 11, 12]]) c = torch.stack([a, b], dim=0) print(c) ``` 输出结果为: ``` tensor([[[ 1, 2, 3], [ 4, 5, 6]], [[ 7, 8, 9], [10, 11, 12]]]) ``` 可以看到,c是将a和b在一个新的维度上堆叠起来得到的结果。 总之,torch.cat()适用于沿着一个已有的维度拼接张量,而torch.stack()则适用于在新的维度上堆叠张量。两者之间的选择应根据实际需求进行。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

视觉萌新、

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值