torch.stack()
沿指定维度堆叠的所有张量
torch.stack(seq, dim=0, *, out=None)
oseq
(张量序列) - 要连接的张量序列。所有张量必须具有相同的形状。dim
(int,可选)- 张量将沿其连接的维度。默认值为 0。out
(张量,可选) - 输出张量- dim=0
-
import torch # create three 1-dimensional tensors of length 3 x = torch.tensor([1, 2, 3]) y = torch.tensor([4, 5, 6]) z = torch.tensor([7, 8, 9]) # stack the tensors along the first dimension stacked = torch.stack([x, y, z],dim=0) # print the stacked tensor print(stacked)
tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
dim=1
import torch # create three 1-dimensional tensors of length 3 x = torch.tensor([1, 2, 3]) y = torch.tensor([4, 5, 6]) # stack the tensors along the second dimension stacked = torch.stack([x, y], dim=1) # print the stacked tensor print(stacked)
tensor([[1, 4], [2, 5], [3, 6]])
torch.squeeze()
用于从张量中删除大小为 1 的任何维度。它将张量作为输入,并返回一个新的张量,删除了大小为 1 的所有维度。
函数
torch.squeeze(input, dim=None, *, out=None)
input
(张量) - 输入张量。dim
(整数或整数元组,可选)- 要压缩的维度。如果未指定,则将删除大小为 1 的所有尺寸。out
(张量,可选) - 输出张量。
torch.squeeze()例子1
import torch
# create a 1x3x1 tensor
x = torch.tensor([[[1], [2], [3]]])
# remove the dimension of size 1 using torch.squeeze()
y = torch.squeeze(x)
# print the shapes of the tensors
print(x.shape) # (1, 3, 1)
print(y.shape) # (3,)
torch.Size([1, 3, 1])
torch.Size([3])
torch.squeeze()例子2
import torch
# create a 1x3x1 tensor
x = torch.tensor([[[1], [2], [3]]])
# remove the second dimension using torch.squeeze()
y = torch.squeeze(x, dim=1)
# print the shapes of the tensors
print(x.shape) # (1, 3, 1)
print(y.shape) # (1, 3)
torch.Size([1, 3, 1])
torch.Size([1, 3])