torch.squeeze(input: Tensor, dim=None)
不指定维度时,删除所有大小为1的维度
- 指定维度的大小为1,删除指定维度
- 指定的维度大小不为1,不做任何改变
示例
import torch
torch.manual_seed(0)
tensor_ = torch.randn(20).reshape(1, 2, 1, 2, 1, 5)
tensor_A = torch.squeeze(tensor_)
tensor_B = torch.squeeze(tensor_, dim=2)
tensor_C = torch.squeeze(tensor_, dim=1)
# 大小为1的维度全没了
print(tensor_A.shape) # torch.Size([2, 2, 5])
# 指定维度,且指定维度大小为1的维度没了
print(tensor_B.shape) # torch.Size([1, 2, 2, 1, 5])
# 指定维度,指定维度大小不为1,不做任何改变
print(tensor_C.shape) # torch.Size([1, 2, 1, 2, 1, 5])