torch.stack详细用法



1、torch.stack

  torch.stack() 函数用于沿着新的维度(堆叠维度)对输入张量序列进行连接。这个函数将一系列张量堆叠在一起以创建一个新的张量。以下是 torch.stack() 函数的一般语法:

torch.stack(tensors, dim=0, out=None)
"""
tensors:张量列表,包含要堆叠的输入张量。
dim:指定要沿着哪个维度进行堆叠。默认值是 0。
out:这是一个可选参数,用于指定输出张量的位置。
"""

2、限制

  • 输入张量的形状需一致:张量的形状在除了指定的维度dim之外,其他的维度需要保持一致。
  • 至少需要两个输入张量:至少两个张量作为输入进行堆叠操作。如果只有一个张量,无法进行堆叠操作。
  • 输出张量的形状+1:输出张量将在指定的维度上增加一个维度。
  • 数据类型一致性:输入张量的数据类型应该保持一致,否则在堆叠时可能会出现数据类型不匹配的情况。

3、例子

3.1、shape的变化

import torch

x1 = torch.randn(2,3,1)
x2 = torch.randn(2,3,1)
x3 = torch.randn(2,3,1)
x4 = torch.randn(2,3,1)

stack_tensor_0 = torch.stack([x1,x2,x3,x4],dim=0)
stack_tensor_1 = torch.stack([x1,x2,x3,x4],dim=1)
stack_tensor_2 = torch.stack([x1,x2,x3,x4],dim=2)
stack_tensor_3 = torch.stack([x1,x2,x3,x4],dim=3)


print('stack_tensor_0.shape:',stack_tensor_0.shape)# torch.Size([4, 2, 3, 1])
print('stack_tensor_1.shape:',stack_tensor_1.shape)# torch.Size([2, 4, 3, 1]
print('stack_tensor_2.shape:',stack_tensor_2.shape)# torch.Size([2, 3, 4, 1])
print('stack_tensor_3.shape:',stack_tensor_3.shape)# torch.Size([2, 3, 1, 4])

3.2、dim = 0

x = torch.tensor([1, 2, 3]) # torch.Size([3])
y = torch.tensor([4, 5, 6]) # torch.Size([3])
stacked_tensor = torch.stack([x, y], dim=0)
print('stacked_tensor',stacked_tensor)
print('stacked_tensor_shape',stacked_tensor.shape)
stacked_tensor tensor([[1, 2, 3],
        			  [4, 5, 6]])
stacked_tensor_shape torch.Size([2, 3])

3.3、dim = 1

x = torch.tensor([1, 2, 3]) # torch.Size([3])
y = torch.tensor([4, 5, 6]) # torch.Size([3])
stacked_tensor = torch.stack([x, y], dim=1)
print('stacked_tensor',stacked_tensor)
print('stacked_tensor_shape',stacked_tensor.shape)
stacked_tensor tensor([[1, 4],
                       [2, 5],
                       [3, 6]])
stacked_tensor_shape torch.Size([3, 2])
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值