squeeze的用法主要就是对数据的维度进行压缩或者解压。
先看torch.squeeze() 这个函数主要对数据的维度进行压缩,去掉维数为1的的维度,比如是一行或者一列这种,一个一行三列(1,3)的数去掉第一个维数为一的维度之后就变成(3)行。squeeze(a)就是将a中所有为1的维度删掉。不为1的维度没有影响。a.squeeze(N) 就是去掉a中指定的维数为一的维度。还有一种形式就是b=torch.squeeze(a,N) a中去掉指定的定的维数为一的维度。
再看torch.unsqueeze()这个函数主要是对数据维度进行扩充。给指定位置加上维数为一的维度,比如原本有个三行的数据(3),在0的位置加了一维就变成一行三列(1,3)。a.squeeze(N) 就是在a中指定位置N加上一个维数为1的维度。还有一种形式就是b=torch.squeeze(a,N) a就是在a中指定位置N加上一个维数为1的维度
在看许多pytorch的代码时,为了计算上的方便,通常会用到unsqueeze函数,一直不得要领,这次专门去做个实验学习一下。
官方文档对这个函数描述如下,就是在指定的位置插入一个维度,有两个参数,input是输入的tensor,dim是要插到的维度
需要注意的是dim的范围是[-input.dim()-1, input.dim()+1),是一个左闭右开的区间,当dim为负值时,会自动转换为dim = dim+input.dim()+1,类似于使用负数对python列表进行切片。
参考:https://blog.csdn.net/ljwwjl/article/details/115342632
import torch
x=torch.tensor([1,2,3,4])
y=torch.unsqueeze(x,0) #就是在指定的位置插入一个维度
print(y)
z=torch.unsqueeze(x,1)
print(z)
# tensor([[1, 2, 3, 4]])
# tensor([[1],
# [2],
# [3],
# [4]])
#下面使用一个二维矩阵看下dim不同时呈现出的效果: 创建一个3*4的全1二维tensor
a=torch.ones(3,4)
# tensor([[1., 1., 1., 1.],
# [1., 1., 1., 1.],
# [1., 1., 1., 1.]])
#在0维度上插入一个维度,可以看到现在a的形状变为[1, 3, 4],第0维度的大小默认是1
a = a.unsqueeze(0)
print(a)
print(a.shape)
# tensor([[[1., 1., 1., 1.],
# [1., 1., 1., 1.],
# [1., 1., 1., 1.]]])
# torch.Size([1, 3, 4])
#在最后一个维度上插入一个维度,形状变为[3, 4, 1]
a = a.unsqueeze(a.dim())
print(a.shape)