torch.stack()与torch.cat() 用法

正文

在pytorch中,常见的拼接函数主要是两个,分别是:

  1. stack()
  2. cat()

stack可以保留两个信息[1.序列]和[2.张量矩阵]信息,先扩张再拼接。

cat()用于拼接多个tensor。

实际使用中两者使用场景不同。

torch.cat() 和python中的内置函数cat(), 在使用和目的上,是没有区别的,区别在于前者操作对象是tensor。
 

1.stack()函数

用法

官方解释:沿着一个新维度对输入张量序列进行连接。 序列中所有的张量都应该为相同形状。

浅显说法:把多个2维的张量凑成一个3维的张量;多个3维的凑成一个4维的张量…以此类推,也就是在增加新的维度进行堆叠

outputs = torch.stack(tensor(sequence), dim(int)) → Tensor

参数:

  • tensors(sequence): 需要连接的张量序列。
  • dim(int):新的维度,在第dim个维度上拼接,0 <= dim < len(outputs)

注意:

  1. python的序列数据只有listtuple
  2. len(outputs)指生成数据的维度大小,即outputs维度值

重点:

  1. 函数输入tensor(sequence)只能是序列;且序列内部的张量元素,必须shape完全一致。举例:[tensor_1, tensor_2,..]或者(tensor_1, tensor_2,..),且必须tensor_1.shape == tensor_2.shape
  2. dim是选择生成的维度,必须满足0<=dim<len(outputs)len(outputs)是输出后的tensor的维度大小。

总结:

  • 作用:函数stack()序列数据内部的张量进行扩维拼接,指定维度由程序员选择、大小是生成后数据的维度区间。
  • 使用意义:通常为了保留–[序列(先后)信息] 和 [张量的矩阵信息] 才会使用stack。

举例

1.准备2个tensor数据,每个的shape都是[3,3]

# 假设是时间步T1的输出
T1 = torch.tensor([[1, 2, 3],[4, 5, 6],[7, 8, 9]])
# 假设是时间步T2的输出
T2 = torch.tensor([[10, 20, 30],[40, 50, 60],[70, 80, 90]])

 2.测试stack函数

print(torch.stack((T1,T2),dim=0).shape)
print(torch.stack((T1,T2),dim=1).shape)
print(torch.stack((T1,T2),dim=2).shape)
print(torch.stack((T1,T2),dim=3).shape)
# outputs:
torch.Size([2, 3, 3])
torch.Size([3, 2, 3])
torch.Size([3, 3, 2])
'选择的dim>len(outputs),所以报错'
IndexError: Dimension out of range (expected to be in range of [-3, 2], but got 3)

拼接后的tensor形状,会根据不同的dim发生变化。

dimshape
0[2, 3, 3]
1[3, 2, 3]
2[3, 3, 2]
3溢出报错

2.cat()函数

用法

函数目的:在给定维度上对输入的张量序列seq 进行连接操作。

outputs = torch.cat(inputs, dim=?) → Tensor

outputs = torch.cat(inputs, dim=?) → Tensor

参数:

  • inputs : 待连接的张量序列,可以是任意相同Tensor类型的python 序列
  • dim : 选择的扩维, 必须在0到len(inputs[0])之间,沿着此维连接张量序列。

重点:

  1. 输入数据必须是序列,序列中数据是任意相同的shape的同类型tensor
  2. 维度不可以超过输入数据的任一个张量的维度

总结:

  • torch.cat()主要用于tensor拼接。

举例

1.准备数据,每个shape都是[2,3]。

# x1
x1 = torch.tensor([[11,21,31],[21,31,41]],dtype=torch.int)
x1.shape # torch.Size([2, 3])
# x2
x2 = torch.tensor([[12,22,32],[22,32,42]],dtype=torch.int)
x2.shape  # torch.Size([2, 3])

2.合成inputs

'inputs为2个形状为[2 , 3]的矩阵 '
inputs = [x1, x2]
print(inputs)
'打印查看'
[tensor([[11, 21, 31],
         [21, 31, 41]], dtype=torch.int32),
 tensor([[12, 22, 32],
         [22, 32, 42]], dtype=torch.int32)]

3.查看结果,测试不同dim拼接结果

In    [1]: torch.cat(inputs, dim=0).shape
Out[1]: torch.Size([4,  3])

In    [2]: torch.cat(inputs, dim=1).shape
Out[2]: torch.Size([2, 6])

In    [3]: torch.cat(inputs, dim=2).shape
IndexError: Dimension out of range (expected to be in range of [-2, 1], but got 2)

补充

​C = torch.cat( (A,B),0 )  # 按维数0拼接(竖着拼)

C = torch.cat( (A,B),1 )  # 按维数1拼接(横着拼)

参考博客

torch.cat()函数的官方解释,详解以及例子_xinjieyuan的博客-CSDN博客_torch.cat

Pytorch中的torch.cat()函数 - 知乎 

Pytorch中的torch.cat()函数 - 不愿透漏姓名的王建森 - 博客园 

torch.stack()的官方解释,详解以及例子_xinjieyuan的博客-CSDN博客_torch.stack() 

  • 3
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值