pytorch中squeeze()和unsqueeze()函数的作用
squeeze()的函数定义:
torch.squeeze(input, dim=None, out=None) → Tensor
返回一个张量,其中所有大小为1的输入的维都已删除。
举个例子,如果输入张量的shape为(A×1×B×C×1×D) ,那么输出张量的shape是(A×B×C×D) .
如果指定了dim,则仅在给定维度上执行挤压操作。如果输入的形状为:(A×1×B),则squeeze(input,0)保持张量不变,但squeeze(input,1)会将张量压缩为形状(A×B)。本身的size没有发生改变。
例子:
x = torch.zeros(2,2,1,2,1,2)
x.size()
Out[10]: torch.Size([2, 2, 1, 2, 1, 2])
y = torch.squeeze(x)
y.size()
Out[12]: torch.Size([2, 2, 2, 2])
y = torch.squeeze(x,0)
y.size()
Out[14]: torch.Size([2, 2, 1, 2, 1, 2])
y = torch.squeeze(x,2)
y.size()
Out[16]: torch.Size([2, 2, 2, 1, 2])
unsqueeze()的函数定义:
torch.unsqueeze(input, dim, out=None) → Tensor
返回在指定位置插入尺寸为1的新张量。
x = torch.tensor([1,2,3,4,5])
x
Out[21]: tensor([1, 2, 3, 4, 5])
x.size()
Out[22]: torch.Size([5])
torch.unsqueeze(x,0)
Out[23]: tensor([[1, 2, 3, 4, 5]])
x.size()
Out[24]: torch.Size([5])
b = torch.unsqueeze(x,0)
b.size()
Out[26]: torch.Size([1, 5])
torch.unsqueeze(x,1)
Out[27]:
tensor([[1],
[2],
[3],
[4],
[5]])
c = torch.unsqueeze(x,1)
c.size()
Out[29]: torch.Size([5, 1])