1、torch.stack
torch.stack() 函数用于沿着新的维度(堆叠维度)对输入张量序列进行连接。这个函数将一系列张量堆叠在一起以创建一个新的张量。以下是 torch.stack() 函数的一般语法:
torch.stack(tensors, dim=0, out=None)
"""
tensors:张量列表,包含要堆叠的输入张量。
dim:指定要沿着哪个维度进行堆叠。默认值是 0。
out:这是一个可选参数,用于指定输出张量的位置。
"""
2、限制
- 输入张量的形状需一致:张量的形状在除了指定的维度dim之外,其他的维度需要保持一致。
- 至少需要两个输入张量:至少两个张量作为输入进行堆叠操作。如果只有一个张量,无法进行堆叠操作。
- 输出张量的形状+1:输出张量将在指定的维度上增加一个维度。
- 数据类型一致性:输入张量的数据类型应该保持一致,否则在堆叠时可能会出现数据类型不匹配的情况。
3、例子
3.1、shape的变化
import torch
x1 = torch.randn(2,3,1)
x2 = torch.randn(2,3,1)
x3 = torch.randn(2,3,1)
x4 = torch.randn(2,3,1)
stack_tensor_0 = torch.stack([x1,x2,x3,x4],dim=0)
stack_tensor_1 = torch.stack([x1,x2,x3,x4],dim=1)
stack_tensor_2 = torch.stack([x1,x2,x3,x4],dim=2)
stack_tensor_3 = torch.stack([x1,x2,x3,x4],dim=3)
print('stack_tensor_0.shape:',stack_tensor_0.shape)# torch.Size([4, 2, 3, 1])
print('stack_tensor_1.shape:',stack_tensor_1.shape)# torch.Size([2, 4, 3, 1]
print('stack_tensor_2.shape:',stack_tensor_2.shape)# torch.Size([2, 3, 4, 1])
print('stack_tensor_3.shape:',stack_tensor_3.shape)# torch.Size([2, 3, 1, 4])
3.2、dim = 0
x = torch.tensor([1, 2, 3]) # torch.Size([3])
y = torch.tensor([4, 5, 6]) # torch.Size([3])
stacked_tensor = torch.stack([x, y], dim=0)
print('stacked_tensor',stacked_tensor)
print('stacked_tensor_shape',stacked_tensor.shape)
stacked_tensor tensor([[1, 2, 3],
[4, 5, 6]])
stacked_tensor_shape torch.Size([2, 3])
3.3、dim = 1
x = torch.tensor([1, 2, 3]) # torch.Size([3])
y = torch.tensor([4, 5, 6]) # torch.Size([3])
stacked_tensor = torch.stack([x, y], dim=1)
print('stacked_tensor',stacked_tensor)
print('stacked_tensor_shape',stacked_tensor.shape)
stacked_tensor tensor([[1, 4],
[2, 5],
[3, 6]])
stacked_tensor_shape torch.Size([3, 2])