1、torch.unsqueeze()
函数
语法:
torch.unsqueeze(input, dim=None)
参数:
input----输入张量
dim------在这个维度上插入一个新的维度
作用:
在指定位置插入一个新的维度
注意:
dim的范围是 [ -input.dim() -1, input.dim()+1 ),是一个左闭右开的区间,当dim为负值时,会自动转换为dim = dim+input.dim()+1
示例:
a = torch.ones(2, 3, 2) # 创建一个全1的三维张量
'''a输出结果:
tensor([[[1., 1.],
[1., 1.],
[1., 1.]],
[[1., 1.],
[1., 1.],
[1., 1.]]])
'''
a.size() # (2,3,2)
b = torch.unsqueeze(a, dim=1) # 在dim=1插入一个维度
# b=a.unsqueeze(dim=1) 另一种写法
'''b输出结果:
tensor([[[[1., 1.],
[1., 1.],
[1., 1.]]],
[[[1., 1.],
[1., 1.],
[1., 1.]]]])
'''
b.size() # (2,1,3,2)
经过unsqueeze(dim=1)函数之后,张量维度由(2,3,2)转换为(2,1,3,2),在dim=1处插入了一个维度。
之前经常看不对张量的size,在这里记录一个方法:
# 以上文的b为例:
tensor([[[[1., 1.],
[1., 1.],
[1., 1.]]],
[[[1., 1.],
[1., 1.],
[1., 1.]]]]) # b.size()=(2,1,3,2)
下面对照着中括号来看:
首先,最外边的中括号:无论多少维的张量,最外边都有一层中括号,所以最外边的中括号本身不包含任何维度信息;
第二层: 从垂直方向上看,有两层,说明该维度长度为2;
第三层: 由于上一层有两层,输出也将其分成了上下两块,只需要看第一块即可(后面的维度都是只看第一块),从垂直方向上看,有一层,说明该维度长度为1;
第四层: 从垂直方向上看,有3层,说明该维度长度为3;
第五层: 也是最内层,随便选择一个中括号,观察中括号内有两个元素,说明该维度为2。
所以,b的size为(2,1,3,2)
2、unsqueeze_()函数
unsqueeze_
和 unsqueeze
实现一样的功能,区别在于 unsqueeze_
是 in_place 操作,即 unsqueeze
不会对使用 unsqueeze
的 tensor 进行改变,想要获取 unsqueeze
后的值必须赋予个新值, unsqueeze_
则会对自己改变。
示例:
a = torch.Tensor([1, 2, 3, 4])
print(a)
# tensor([1., 2., 3., 4.])
b = torch.unsqueeze(a, 1)
print(b)
# tensor([[1.],
# [2.],
# [3.],
# [4.]])
print(a)
# tensor([1., 2., 3., 4.])
print(a.unsqueeze_(1))
# tensor([[1.],
# [2.],
# [3.],
# [4.]])
print(a)
# tensor([[1.],
# [2.],
# [3.],
# [4.]])
3、squeeze()函数
语法:
torch.squeeze(input, dim=None)
参数:
input----输入张量
dim------如果给定,只会在这个维度上对输入张量进行降维
作用:
当未给定dim时,squeeze
将输入张量形状中的1 去除并返回。 如果输入是形如(A×1×B×1×C×1×D),那么输出形状就为: (A×B×C×D)
a = torch.zeros(1, 2, 1, 3, 1, 4)
b=torch.squeeze(a) # 未给定dim,squeeze对整个张量进行挤压操作,将张量中的1全部去除
b.size()
#输出: torch.Size([2, 3, 4])
当给定dim时,那么挤压操作只在给定维度上。例如,输入形状为: (A×1×B), squeeze(input, 0)
将会保持张量不变,只有用 squeeze(input, 1)
,形状会变成 (A×B)。
a = torch.zeros(1, 2, 1, 3, 1, 4)
c = torch.squeeze(a,0) # 当给定dim时,那么挤压操作只在给定维度上
c.size()
#输出: torch.Size([2, 1, 3, 1, 4])
d = torch.squeeze(a, 1) # a在dim=1上维度是2,形状不会发生变化
d.size()
#输出: torch.Size([1, 2, 1, 3, 1, 4])
参考:
https://zhuanlan.zhihu.com/p/86763381
PyTorch入门——数组维度及squeeze、unsqueeze