squeeze:压缩,要减少维度。
unsqueeze:解压缩,要增加维度。
torch.squeeze(input),那么会把input中所有维度长度为1的维度去掉。
torch.squeeze(input,dim=1),那么在给定dim的情况下,就只去掉dim这个维度,其他维度还保留。
import torch
x = torch.rand(5,3)
x = x.squeeze(1)
tensor([[0.0621, 0.2074, 0.5420],
[0.5897, 0.3664, 0.4387],
[0.0115, 0.3464, 0.0702],
[0.7800, 0.4727, 0.1952],
[0.6879, 0.8595, 0.3933]])
这时候x的形状还是5行3列。因为没有哪个维度的长度为1。
x = x.unsqueeze(1)
tensor([[[0.0621, 0.2074, 0.5420]],
[[0.5897, 0.3664, 0.4387]],
[[0.0115, 0.3464, 0.0702]],
[[0.7800, 0.4727, 0.1952]],
[[0.6879, 0.8595, 0.3933]]])
那么x的形状是(5,1,3),有5个块,每个块是1行3列。
对于unsquueze来讲,维度可以比原有维度高1。例如最开始x的形状是(5,3)。可以如下操作。
import torch
x = torch.rand(5,3)
x = x.unsqueeze(2)
tensor([[[0.3757],
[0.8054],
[0.0250]],
[[0.9423],
[0.5109],
[0.2437]],
[[0.6276],
[0.4251],
[0.3276]],
[[0.6699],
[0.0768],
[0.3541]],
[[0.6123],
[0.0268],
[0.4193]]])
那么得到的tensor形状是(5,3,1)。
还是看你想要什么样的形状。