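Removes dimensions of size 1 from a tensor. By default (dim=None), every size-1 dimension is removed. If dim is given, only that dimension is squeezed, and only if its size is 1; otherwise the shape is left unchanged.
Usage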
>>> x = torch.zeros(2, 1, 2, 1, 2)
>>> x.size()
torch.Size([2, 1, 2, 1, 2])
>>> y = torch.squeeze(x)
>>> y.size()
torch.Size([2, 2, 2])
>>> y = torch.squeeze(x, 0)
>>> y.size()
torch.Size([2, 1, 2, 1, 2])
>>> y = torch.squeeze(x, 1)
>>> y.size()
torch.Size([2, 2, 1, 2])
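The rule these examples show can be restated as a small pure-Python sketch. `squeezed_shape` is a hypothetical helper, not part of PyTorch. It computes only the resulting shape, and it assumes a tensor with at least one dimension (0-d tensors are not handled).

```python
from typing import Optional, Sequence, Tuple


def squeezed_shape(shape: Sequence[int], dim: Optional[int] = None) -> Tuple[int, ...]:
    """Hypothetical helper: the shape torch.squeeze(x, dim) gives when x.size() == shape.

    Sketch only; assumes len(shape) >= 1 (0-d tensors are not covered).
    """
    if dim is None:
        # No dim given: drop every dimension whose size is 1
        return tuple(s for s in shape if s != 1)
    ndim = len(shape)
    # Negative dims count from the end, as in PyTorch
    if not -ndim <= dim < ndim:
        raise IndexError(f"dim {dim} out of range for a {ndim}-d shape")
    dim %= ndim
    if shape[dim] != 1:
        # The chosen dimension is not size 1: shape is returned unchanged
        return tuple(shape)
    # Remove only the chosen size-1 dimension
    return tuple(shape[:dim]) + tuple(shape[dim + 1:])


print(squeezed_shape((2, 1, 2, 1, 2)))      # (2, 2, 2)
print(squeezed_shape((2, 1, 2, 1, 2), 0))   # (2, 1, 2, 1, 2)
print(squeezed_shape((2, 1, 2, 1, 2), 1))   # (2, 2, 1, 2)
print(squeezed_shape((2, 1, 2, 1, 2), -2))  # (2, 1, 2, 2)
```

The first three calls reproduce the torch.Size results above; the last one shows that a negative dim is counted from the end.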
API
torch.squeeze(input, dim=None, out=None) → Tensor
Parameter | Description |
---|---|
input (Tensor) | the input tensor. |
dim (int, optional) | if given, the input will be squeezed only in this dimension |
out (Tensor, optional) | the output tensor. |
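A sketch of how the dim parameter behaves in practice, assuming a recent PyTorch install. It shows a negative dim, a no-op on a dimension whose size is not 1, and two points from the official documentation: calling squeeze without dim also drops a size-1 batch dimension, and the result shares storage with the input.

```python
import torch

# A "batch" holding a single sample with one channel: shape (1, 1, 3)
x = torch.zeros(1, 1, 3)

# dim=None removes every size-1 dimension, including the batch dimension
print(torch.squeeze(x).size())      # torch.Size([3])

# Passing dim removes only that dimension; negative values count from the end
print(torch.squeeze(x, 1).size())   # torch.Size([1, 3])
print(torch.squeeze(x, -2).size())  # torch.Size([1, 3])

# Squeezing a dimension whose size is not 1 leaves the shape unchanged
print(torch.squeeze(x, 2).size())   # torch.Size([1, 1, 3])

# The result shares storage with the input, so a write through it is visible in x
y = torch.squeeze(x)
y[0] = 5.0
print(x[0, 0, 0].item())            # 5.0
```

When a batch dimension may be 1, passing an explicit dim avoids accidentally removing it.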
Reference:
https://pytorch.org/docs/stable/generated/torch.squeeze.html#torch.squeeze