深度学习中小知识点系列(十八) 解析pytorch中torch.stack()和torch.cat()(详尽)

在pytorch中,常见的拼接函数主要是两个,分别是:

  1. stack()
  2. cat()

torch.stack()

函数的意义:使用stack可以保留两个信息:[1. 序列] 和 [2. 张量矩阵] 信息,属于【扩张再拼接】的函数。

形象的理解:假如数据都是二维矩阵(平面),它可以把这些一个个平面按第三维(例如:时间序列)压成一个三维的立方体,而立方体的长度就是时间序列长度。该函数常出现在自然语言处理(NLP)和图像卷积神经网络(CV)中。

1 stack()

官方解释:沿着一个新维度对输入张量序列进行连接。 序列中所有的张量都应该为相同形状。

浅显说法:把多个2维的张量凑成一个3维的张量;多个3维的凑成一个4维的张量…以此类推,也就是在增加新的维度进行堆叠。

outputs = torch.stack(inputs, dim=?) → Tensor

参数

inputs : 待连接的张量序列。
注:python的序列数据只有list和tuple。

dim : 新的维度, 必须在0到len(outputs)之间。
注:len(outputs)是生成数据的维度大小,也就是outputs的维度值。

2 重点

  1. 函数中的输入inputs只允许是序列;且序列内部的张量元素,必须shape相等
    ----举例:[tensor_1, tensor_2,…]或者(tensor_1, tensor_2,…),且必须tensor_1.shape == tensor_2.shape
  2. dim是选择生成的维度,必须满足0<=dim<len(outputs);len(outputs)是输出后的tensor的维度大小
    不懂的看例子,再回过头看就懂了。

3 例子

  1. 准备2个tensor数据,每个的shape都是[3,3]
# 假设是时间步T1的输出
T1 = torch.tensor([[1, 2, 3],
        		[4, 5, 6],
        		[7, 8, 9]])
# 假设是时间步T2的输出
T2 = torch.tensor([[10, 20, 30],
        		[40, 50, 60],
        		[70, 80, 90]])
T1:
 tensor([[1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]])
T2:
 tensor([[10, 20, 30],
        [40, 50, 60],
        [70, 80, 90]])
  1. 测试stack函数
R0 = torch.stack((T1, T2), dim=0)
print("R0:\n", R0)
print("R0.shape:\n", R0.shape)
"""
R0:
 tensor([[[ 1,  2,  3],
         [ 4,  5,  6],
         [ 7,  8,  9]],

        [[10, 20, 30],
         [40, 50, 60],
         [70, 80, 90]]])
R0.shape:
 torch.Size([2, 3, 3])
"""

R1 = torch.stack((T1, T2), dim=1)
print("R1:\n", R1)
print("R1.shape:\n", R1.shape)
"""
R1:
 tensor([[[ 1,  2,  3],
         [10, 20, 30]],

        [[ 4,  5,  6],
         [40, 50, 60]],

        [[ 7,  8,  9],
         [70, 80, 90]]])
R1.shape:
 torch.Size([3, 2, 3])

"""

R2 = torch.stack((T1, T2), dim=2)
print("R2:\n", R2)
print("R2.shape:\n", R2.shape)
"""
R2:
 tensor([[[ 1, 10],
         [ 2, 20],
         [ 3, 30]],

        [[ 4, 40],
         [ 5, 50],
         [ 6, 60]],

        [[ 7, 70],
         [ 8, 80],
         [ 9, 90]]])
R2.shape:
 torch.Size([3, 3, 2])

"""

R3 = torch.stack((T1, T2), dim=3)
print("R3:\n", R3)
print("R3.shape:\n", R3.shape)
"""
IndexError: Dimension out of range (expected to be in range of [-3, 2], but got 3)
"""

可以复制代码运行试试:拼接后的tensor形状,会根据不同的dim发生变化。

在这里插入图片描述

4 重点

  1. 函数作用:
    函数stack()对序列数据内部的张量进行扩维拼接,指定维度由程序员选择、大小是生成后数据的维度区间。
  2. 存在意义:
    在自然语言处理和卷及神经网络中, 通常为了保留–[序列(先后)信息] 和 [张量的矩阵信息] 才会使用stack。

函数存在意义?》》》

手写过RNN的同学,知道在循环神经网络中输出数据是:一个list,该列表插入了seq_len个形状是[batch_size, output_size]的tensor,不利于计算,需要使用stack进行拼接,保留–[1.seq_len这个时间步]和–[2.张量属性[batch_size, output_size]]。

torch.cat()

一般torch.cat()是为了把函数torch.stack()得到tensor进行拼接而存在的。torch.cat() 和python中的内置函数cat(), 在使用和目的上,是没有区别的,区别在于前者操作对象是tensor。

1 cat()

函数目的: 在给定维度上对输入的张量序列seq 进行连接操作。

outputs = torch.cat(inputs, dim=0) → Tensor

参数

  • inputs : 待连接的张量序列,可以是任意相同Tensor类型的python 序列。
  • dim : 选择的扩维, 必须在0到len(inputs[0])之间,沿着此维连接张量序列。

2 重点

  • 输入数据必须是序列,序列中数据是任意相同的shape的同类型tensor
  • 维度不可以超过输入数据的任一个张量的维度

3 例子

  1. 准备数据,每个的shape都是[2,3]
x1 = torch.tensor([[11, 21, 31], [21, 31, 41]], dtype=torch.int)
print("x1:\n", x1)
print("x1.shape:\n", x1.shape)
'''
x1:
 tensor([[11, 21, 31],
        [21, 31, 41]], dtype=torch.int32)
x1.shape:
 torch.Size([2, 3])
'''
x2 = torch.tensor([[12, 22, 32], [22, 32, 42]])
print("x2:\n", x2)
print("x2.shape:\n", x2.shape)
'''
x2:
 tensor([[12, 22, 32],
        [22, 32, 42]])
x2.shape:
 torch.Size([2, 3])
'''
  1. 合成inputs
inputs = [x1, x2]
print("inputs:\n", inputs)
'''
inputs:
 [tensor([[11, 21, 31],
        [21, 31, 41]], dtype=torch.int32), tensor([[12, 22, 32],
        [22, 32, 42]])]
'''
  1. 查看结果, 测试不同的dim拼接结果
R0 = torch.cat(inputs, dim=0)
print("R0:\n", R0)
print("R0.shape:\n", R0.shape)
'''
R0:
 tensor([[11, 21, 31],
        [21, 31, 41],
        [12, 22, 32],
        [22, 32, 42]])
R0.shape:
 torch.Size([4, 3])
'''

R1 = torch.cat(inputs, dim=1)
print("R1:\n", R1)
print("R1.shape:\n", R1.shape)
'''
R1:
 tensor([[11, 21, 31, 12, 22, 32],
        [21, 31, 41, 22, 32, 42]])
R1.shape:
 torch.Size([2, 6])
'''

R2 = torch.cat(inputs, dim=2)
print("R2:\n", R2)
print("R2.shape:\n", R2.shape)
'''
IndexError: Dimension out of range (expected to be in range of [-2, 1], but got 2)
'''
  • 10
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

小酒馆燃着灯

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值