顾名思义:unsqueeze,扩展维度,返回一个新的张量,对输入的既定位置插入维度 1
squeeze,压缩维度,将输入张量形状中的1 去除并返回。
torch.unsqueeze(input, dim)
torch.squeeze(input, dim)
tensor (Tensor)
– 输入张量dim (int)
– 插入/消除 维度的索引
以下用一个二维张量进行举例:
压缩维度仅对(0,1)索引进行示例,(-1,-2)原理类似
import torch
x = torch.Tensor([[1, 2, 3, 4],
[5,6,7,8]])
print('#' * 50)
print(x)
print(x.size())
print(x.dim())
##########
print('#' * 50)
print(torch.unsqueeze(x, 0))
print(torch.unsqueeze(x, 0).size())
print(torch.unsqueeze(x, 0).dim())
m=torch.unsqueeze(x, 0)
print(m.squeeze(0))
n=m.squeeze(0)
print(n.size())
print(n.dim())
##########
print('#' * 50)
print(torch.unsqueeze(x, 1))
print(torch.unsqueeze(x, 1).size())
print(torch.unsqueeze(x, 1).dim())
a=torch.unsqueeze(x, 1)
print(a.squeeze(1))
b=a.squeeze(1)
print(b.size())
print(b.dim())
##########
print('#' * 50)
print(torch.unsqueeze(x, -1))
print(torch.unsqueeze(x, -1).size())
print(torch.unsqueeze(x, 1).dim())
##########
print('#' * 50)
print(torch.unsqueeze(x, -2))
print(torch.unsqueeze(x, -2).size())
print(torch.unsqueeze(x, -2).dim())
相应结果:
##################################################
tensor([[1., 2., 3., 4.],
[5., 6., 7., 8.]])
torch.Size([2, 4])
2
##################################################
tensor([[[1., 2., 3., 4.],
[5., 6., 7., 8.]]])
torch.Size([1, 2, 4])
3
tensor([[1., 2., 3., 4.],
[5., 6., 7., 8.]])
torch.Size([2, 4])
2
##################################################
tensor([[[1., 2., 3., 4.]],
[[5., 6., 7., 8.]]])
torch.Size([2, 1, 4])
3
tensor([[1., 2., 3., 4.],
[5., 6., 7., 8.]])
torch.Size([2, 4])
2
##################################################
tensor([[[1.],
[2.],
[3.],
[4.]],
[[5.],
[6.],
[7.],
[8.]]])
torch.Size([2, 4, 1])
3
##################################################
tensor([[[1., 2., 3., 4.]],
[[5., 6., 7., 8.]]])
torch.Size([2, 1, 4])
3