用于去掉维数为1的维度
输入
import torch
a=torch.tensor([[[1,2,3],
[4,5,6]]])
print(a)
print(a.shape)
结果
tensor([[[1, 2, 3],
[4, 5, 6]]])
torch.Size([1, 2, 3])
输入
b=a.squeeze()
print(b)
print(b.shape)
输出:将torch.Size([1, 2, 3])中的1维去掉了
tensor([[1, 2, 3],
[4, 5, 6]])
torch.Size([2, 3])
下面试验一下shape里面有两个维数1,是否可以全部去掉
输入
import torch
a=torch.randn(1,2,1,3)
print(a)
print(a.shape)
b=a.squeeze()
print(b)
print(b.shape)
输出
tensor([[[[ 0.4250, 0.7191, 0.4334]],
[[-1.4933, -0.5805, -0.4109]]]])
torch.Size([1, 2, 1, 3])
tensor([[ 0.4250, 0.7191, 0.4334],
[-1.4933, -0.5805, -0.4109]])
torch.Size([2, 3])
是否可以制定去掉哪个维度
torch.squeeze(input, dim=None, *, out=None)
input是需要操作的tensor数据类型,dim指定需要去掉的维度
也可以是
input.squeeze(dim=None, *, out=None)
输入:
import torch
a=torch.randn(1,2,1,3)
print(a)
print(a.shape)
b=a.squeeze(0)
print(b)
print(b.shape)
输出
tensor([[[[ 0.0865, 1.7813, -0.2721]],
[[ 0.2280, 0.7990, 0.5873]]]])
torch.Size([1, 2, 1, 3])
tensor([[[ 0.0865, 1.7813, -0.2721]],
[[ 0.2280, 0.7990, 0.5873]]])
torch.Size([2, 1, 3])
如果指定去掉的维度不是1维,会怎样
输入
import torch
a=torch.randn(1,2,1,3)
print(a)
print(a.shape)
b=a.squeeze(1)
print(b)
print(b.shape)
输出 不会发生变化
tensor([[[[ 1.2664, -0.7684, -1.3079]],
[[-0.5082, 2.2029, 0.1442]]]])
torch.Size([1, 2, 1, 3])
tensor([[[[ 1.2664, -0.7684, -1.3079]],
[[-0.5082, 2.2029, 0.1442]]]])
torch.Size([1, 2, 1, 3])