squeeze是降维,unsqueeze是升维。
那升降维是啥意思呢,找一个二维数组(tensor)当例子
import torch
x = torch.tensor([[1, 2, 3],[1, 2, 3],[1, 2, 3]])
print(torch.unsqueeze(x,0))
输出结果为:
tensor([[[1, 2, 3],
[1, 2, 3],
[1, 2, 3]]])
原来是个三行三列的数组,记为(3,3)
可见整体升了一维,加了一个中括号(1,3,3)
print(torch.unsqueeze(x,1))
输出结果为:
tensor([[[1, 2, 3]],
[[1, 2, 3]],
[[1, 2, 3]]])
原来的第一维升了一维,(3,1,3)
print(torch.unsqueeze(x,2))
输出结果为:
tensor([[[1],
[2],
[3]],
[[1],
[2],
[3]],
[[1],
[2],
[3]]])
原来的第一维升了一维,(3,3,1)
squeeze是降维,和unsqueeze是相反的操作
print(torch.squeeze(torch.unsqueeze(x,2),2))
输出结果为:
tensor([[1, 2, 3],
[1, 2, 3],
[1, 2, 3]])
torch.unsqueeze(x,2)的结果为(3,3,1),torch.squeeze(input,2)发现它的第二维里面有一维的就把一维的取消掉,也就是中括号去掉。
print(torch.squeeze(torch.unsqueeze(x,2),1))
则不改变,因为torch.unsqueeze(x,2)得第一维里面是二维 ,所以不改变。只消去指定维里维数为1的。