x = torch.randn(3, 1, 2)
x
tensor([[[-0.1986, 0.4352]],
[[ 0.0971, 0.2296]],
[[ 0.8339, -0.5433]]])
x.squeeze().size() # 不加参数,去掉所有为元素个数为1的维度
torch.Size([3, 2])x.squeeze()
tensor([[-0.1986, 0.4352],
[ 0.0971, 0.2296],
[ 0.8339, -0.5433]])torch.squeeze(x, 0).size() # 加上参数,去掉第一维的元素,不起作用,因为第一维有2个元素
torch.Size([3, 1, 2])torch.squeeze(x, 1).size() # 加上参数,去掉第二维的元素,正好为 1,起作用
torch.Size([3, 2])