pytorch中unsqueeze()、unsqueeze_()、squeeze()函数的作用与区别

1、torch.unsqueeze()函数

语法:

torch.unsqueeze(input, dim=None)  

参数:
input----输入张量
dim------在这个维度上插入一个新的维度

作用:
在指定位置插入一个新的维度

注意:
dim的范围是 [ -input.dim() -1, input.dim()+1 ),是一个左闭右开的区间,当dim为负值时,会自动转换为dim = dim+input.dim()+1

示例:

a = torch.ones(2, 3, 2)          # 创建一个全1的三维张量
'''a输出结果:
tensor([[[1., 1.],
         [1., 1.],
         [1., 1.]],
        [[1., 1.],
         [1., 1.],
         [1., 1.]]])
'''
a.size()                       # (2,3,2)
b = torch.unsqueeze(a, dim=1)  # 在dim=1插入一个维度
# b=a.unsqueeze(dim=1)  另一种写法
'''b输出结果:
tensor([[[[1., 1.],
          [1., 1.],
          [1., 1.]]],
        [[[1., 1.],
          [1., 1.],
          [1., 1.]]]])
'''
b.size()                       # (2,1,3,2)

经过unsqueeze(dim=1)函数之后,张量维度由(2,3,2)转换为(2,1,3,2),在dim=1处插入了一个维度。

之前经常看不对张量的size,在这里记录一个方法:

# 以上文的b为例:
tensor([[[[1., 1.],
          [1., 1.],
          [1., 1.]]],
        [[[1., 1.],
          [1., 1.],
          [1., 1.]]]])   # b.size()=(2,1,3,2)

下面对照着中括号来看:
首先,最外边的中括号:无论多少维的张量,最外边都有一层中括号,所以最外边的中括号本身不包含任何维度信息;
第二层: 从垂直方向上看,有两层,说明该维度长度为2;
第三层: 由于上一层有两层,输出也将其分成了上下两块,只需要看第一块即可(后面的维度都是只看第一块),从垂直方向上看,有一层,说明该维度长度为1;
第四层: 从垂直方向上看,有3层,说明该维度长度为3;
第五层: 也是最内层,随便选择一个中括号,观察中括号内有两个元素,说明该维度为2。
所以,b的size为(2,1,3,2)

2、unsqueeze_()函数

unsqueeze_unsqueeze 实现一样的功能,区别在于 unsqueeze_in_place 操作,即 unsqueeze 不会对使用 unsqueeze 的 tensor 进行改变,想要获取 unsqueeze 后的值必须赋予个新值, unsqueeze_ 则会对自己改变。
示例:

a = torch.Tensor([1, 2, 3, 4])
print(a)
# tensor([1., 2., 3., 4.])

b = torch.unsqueeze(a, 1)
print(b)
# tensor([[1.],
#         [2.],
#         [3.],
#         [4.]])

print(a)
# tensor([1., 2., 3., 4.])

print(a.unsqueeze_(1))
# tensor([[1.],
#         [2.],
#         [3.],
#         [4.]])

print(a)
# tensor([[1.],
#         [2.],
#         [3.],
#         [4.]])

3、squeeze()函数

语法:

torch.squeeze(input, dim=None)

参数:
input----输入张量
dim------如果给定,只会在这个维度上对输入张量进行降维

作用:
当未给定dim时,squeeze将输入张量形状中的1 去除并返回。 如果输入是形如(A×1×B×1×C×1×D),那么输出形状就为: (A×B×C×D)

a = torch.zeros(1, 2, 1, 3, 1, 4)
b=torch.squeeze(a)  # 未给定dim,squeeze对整个张量进行挤压操作,将张量中的1全部去除
b.size()
#输出: torch.Size([2, 3, 4])

当给定dim时,那么挤压操作只在给定维度上。例如,输入形状为: (A×1×B), squeeze(input, 0) 将会保持张量不变,只有用 squeeze(input, 1),形状会变成 (A×B)。

a = torch.zeros(1, 2, 1, 3, 1, 4)
c = torch.squeeze(a,0)  # 当给定dim时,那么挤压操作只在给定维度上
c.size()
#输出: torch.Size([2, 1, 3, 1, 4])
d = torch.squeeze(a, 1)  # a在dim=1上维度是2,形状不会发生变化
d.size()
#输出: torch.Size([1, 2, 1, 3, 1, 4])

参考:
https://zhuanlan.zhihu.com/p/86763381
PyTorch入门——数组维度及squeeze、unsqueeze

  • 3
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值