torch.squeeze用法
torch.squeeze(input, dim=None,out=None)
-
当不给定任何维度时
这个函数用于将输入的张量中所有维度的大小为1的维度全部去掉(这边描述的可能有点绕,看下面示例就会很清晰) -
当指定维度时
如果指定的维度,大小为1,则删除该维度,如果大小不为1则保留
参数
- input(Tensor) 输入的张量
- dim(int) 指定删除的维度,如果不给定,就要删除所有维度为1的维度
- out(Tensor) 得到的输出张量,(可写可不写,因为我们一般都用y=x.squeeze())
示例
x = torch.zeros(2, 1, 2, 1, 2)
x.size()
>>>torch.Size([2, 1, 2, 1, 2])
y = torch.squeeze(x) # 不指定维度,删除所有维度为1的维度
y.size()
>>>torch.Size([2, 2, 2])
y = torch.squeeze(x, 0) # 指定维度0,但是维度0的size不为1,所以保留
y.size()
>>>torch.Size([2, 1, 2, 1, 2])
y = torch.squeeze(x, 1) # 指定维度1,维度1的size为1,所以删除
y.size()
>>>torch.Size([2, 2, 1, 2])