About torch.squeeze() and torch.unsqueeze() in Python (PyTorch)
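torch.squeeze() removes dimensions.
torch.unsqueeze() adds a dimension.
torch.squeeze() only acts on dimensions of size 1; dimensions of any other size are left alone. If dim is given, only that position is checked (and removed if its size is 1); if dim is omitted, every size-1 dimension is removed.
torch.unsqueeze() inserts a new dimension of size 1 at position dim. dim is required, and omitting it raises an error.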
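As a cross-check of these rules, here is a minimal pure-Python model of the shapes the two functions produce. It assumes PyTorch's usual dim conventions: a negative dim counts from the end, and an out-of-range dim raises IndexError. The helpers squeeze_shape and unsqueeze_shape are hypothetical illustrations, not PyTorch functions, and 0-d tensors are not modelled.

```python
def squeeze_shape(shape, dim=None):
    """Hypothetical model of the shape returned by torch.squeeze."""
    shape = tuple(shape)
    if dim is None:
        # No dim given: drop every size-1 dimension.
        return tuple(s for s in shape if s != 1)
    ndim = len(shape)
    # Valid dims are -ndim .. ndim-1; negative values count from the end.
    # (0-d shapes are not modelled: they always hit this IndexError.)
    if not -ndim <= dim < ndim:
        raise IndexError(f"dim {dim} out of range for a {ndim}-d shape")
    dim %= ndim
    if shape[dim] != 1:
        # Size is not 1: the shape comes back unchanged, no error.
        return shape
    return shape[:dim] + shape[dim + 1:]


def unsqueeze_shape(shape, dim):
    """Hypothetical model of the shape returned by torch.unsqueeze."""
    shape = tuple(shape)
    ndim = len(shape)
    # The result has ndim + 1 dims, so valid dims are -(ndim + 1) .. ndim.
    if not -(ndim + 1) <= dim <= ndim:
        raise IndexError(f"dim {dim} out of range for a {ndim}-d shape")
    dim %= ndim + 1
    return shape[:dim] + (1,) + shape[dim:]


print(squeeze_shape((2, 1, 2, 1, 3), 3))        # (2, 1, 2, 3)
print(squeeze_shape((2, 1, 2, 1, 3), 4))        # (2, 1, 2, 1, 3)
print(unsqueeze_shape((2, 1, 4, 8, 1, 3), -1))  # (2, 1, 4, 8, 1, 3, 1)
```

The model reproduces every shape in the sessions below, which makes it a quick way to predict a result before running PyTorch.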
Example 1
>>> import torch
>>> x=torch.zeros(2,1,2,1,3)
>>> x.size()
torch.Size([2, 1, 2, 1, 3])
>>> y=torch.squeeze(x,3)
>>> y.size()
torch.Size([2, 1, 2, 3])
>>> y=torch.squeeze(x,4)
>>> y.size()
torch.Size([2, 1, 2, 1, 3])
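Example 1 can also be run as a short script that tries every dim of x in turn. This is a sketch assuming PyTorch is installed; it shows that only dims 1 and 3 (the size-1 ones) are removed, while every other call returns the shape unchanged instead of raising.

```python
import torch

x = torch.zeros(2, 1, 2, 1, 3)

# Squeeze each dimension in turn and record the resulting shape.
results = {d: tuple(torch.squeeze(x, d).shape) for d in range(x.dim())}

for d, shape in results.items():
    # Only a dim whose size is 1 is actually removed.
    status = "removed" if x.shape[d] == 1 else "unchanged"
    print(d, x.shape[d], shape, status)
```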
Example 2
>>> z=torch.zeros(2,1,4,8,1,3)
>>> z.size()
torch.Size([2, 1, 4, 8, 1, 3])
>>> y1=torch.squeeze(z)
>>> y1.size()
torch.Size([2, 4, 8, 3])
>>> y2=torch.squeeze(z,0)
>>> y2.size()
torch.Size([2, 1, 4, 8, 1, 3])
>>> y2=torch.squeeze(z,1)
>>> y2.size()
torch.Size([2, 4, 8, 1, 3])
>>> y3=torch.squeeze(z,4)
>>> y3.size()
torch.Size([2, 1, 4, 8, 3])
>>> y4=torch.squeeze(z,5)
>>> y4.size()
torch.Size([2, 1, 4, 8, 1, 3])
>>> y4=torch.squeeze(z,3)
>>> y4.size()
torch.Size([2, 1, 4, 8, 1, 3])
>>> y4=torch.squeeze(z,2)
>>> y4.size()
torch.Size([2, 1, 4, 8, 1, 3])
>>> y6=torch.unsqueeze(z,0)
>>> y6.size()
torch.Size([1, 2, 1, 4, 8, 1, 3])
>>> y7=torch.unsqueeze(z,1)
>>> y7.size()
torch.Size([2, 1, 1, 4, 8, 1, 3])
>>> y8=torch.unsqueeze(z,-1)
>>> y8.size()
torch.Size([2, 1, 4, 8, 1, 3, 1])
>>> y9=torch.unsqueeze(z,-0)
>>> y9.size()
torch.Size([1, 2, 1, 4, 8, 1, 3])
>>> y9=torch.unsqueeze(z,2)
>>> y9.size()
torch.Size([2, 1, 1, 4, 8, 1, 3])
>>> y9=torch.unsqueeze(z,3)
>>> y9.size()
torch.Size([2, 1, 4, 1, 8, 1, 3])
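The session above can be condensed into a script that tries every position on z. This is a sketch assuming PyTorch is installed: it lists which squeeze dims are no-ops, shows that a negative dim counts from the end for squeeze just as it does for unsqueeze, and checks that an unsqueeze position d < 0 lands in the same place as d + n + 1, where n = z.dim().

```python
import torch

z = torch.zeros(2, 1, 4, 8, 1, 3)
n = z.dim()  # 6

# squeeze: valid dims are -n .. n-1; only the size-1 dims (1 and 4) change z.
untouched = [d for d in range(n) if torch.squeeze(z, d).shape == z.shape]
print(untouched)                   # [0, 2, 3, 5]
print(torch.squeeze(z, -2).shape)  # -2 means dim 4: torch.Size([2, 1, 4, 8, 3])

# unsqueeze: valid positions are -(n + 1) .. n, because the result has n + 1 dims.
shapes = {d: tuple(torch.unsqueeze(z, d).shape) for d in range(-(n + 1), n + 1)}

# A negative position d inserts at the same place as d + n + 1.
for d in range(-(n + 1), 0):
    assert shapes[d] == shapes[d + n + 1]

print(shapes[3])   # (2, 1, 4, 1, 8, 1, 3)
print(shapes[-1])  # (2, 1, 4, 8, 1, 3, 1)
```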
Passing torch.squeeze() an extra argument raises an error:
>>> y6=torch.squeeze(z,2,2)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: squeeze() received an invalid combination of arguments - got (Tensor, int, int), but expected one of:
 * (Tensor input)
 * (Tensor input, int dim)
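To squeeze several specific dims, one approach is to chain single-dim calls, starting from the highest index so that removing one dim does not shift the others. This is a sketch assuming PyTorch is installed; squeeze_dims is a hypothetical helper, not a PyTorch function. According to the PyTorch documentation, version 2.0 and later also accept a tuple such as torch.squeeze(z, (1, 4)); the chained form below does not depend on that.

```python
import torch

z = torch.zeros(2, 1, 4, 8, 1, 3)


def squeeze_dims(t, dims):
    """Hypothetical helper: squeeze several dims via single-dim calls."""
    # Normalise negative dims and drop duplicates, then go from the highest
    # index down so earlier removals don't shift the indices still to come.
    for d in sorted({d % t.dim() for d in dims}, reverse=True):
        t = torch.squeeze(t, d)
    return t


print(squeeze_dims(z, [1, 4]).shape)  # torch.Size([2, 4, 8, 3])

# Ascending order goes wrong: once dim 1 is removed, the old dim 4 becomes
# dim 3, so squeezing "4" hits the size-3 dim and silently does nothing.
wrong = torch.squeeze(torch.squeeze(z, 1), 4)
print(wrong.shape)  # torch.Size([2, 4, 8, 1, 3])
```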
Omitting the position in torch.unsqueeze() raises an error:
>>> y8=torch.unsqueeze(z)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: unsqueeze() missing 1 required positional arguments: "dim"
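The fix is to pass dim, either positionally or by keyword. The sketch below (assuming PyTorch is installed) also shows the other unsqueeze error case: a position outside -(n + 1) .. n, which is -7 .. 6 for the 6-d z, raises an IndexError.

```python
import torch

z = torch.zeros(2, 1, 4, 8, 1, 3)

# dim must always be supplied: positionally, by keyword, or via the method form.
a = torch.unsqueeze(z, 0)
b = torch.unsqueeze(z, dim=0)
c = z.unsqueeze(dim=0)
assert a.shape == b.shape == c.shape == (1, 2, 1, 4, 8, 1, 3)

# A position outside -7 .. 6 is rejected with an IndexError.
raised = False
try:
    torch.unsqueeze(z, 7)
except IndexError as err:
    raised = True
    print("IndexError:", err)
```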