Pytorch squeeze() unsqueeze() 用法

简介

torch.squeeze(input, dim=None, out=None):对数据的维度进行压缩,去掉维数为1的的维度。
squeeze函数功能:去除size为1的维度,包括行和列。当维度大于等于2时,squeeze()无作用。
squeeze(0):代表若第一维度值为1则去除第一维度,例如 a.squeeze(0),a 为 torch.tensor() 格式张量。
squeeze(1):代表若第二维度值为1则去除第二维度
squeeze(-1):去除最后维度值为1的维度

torch.unsqueeze (input, dim=None, out=None):对数据的维度进行扩容,即升维。

使用格式可以是torch.unsqueeze(x, 0),也可为是x.unsqueeze(0)

实例代码

a = torch.Tensor(1, 3)
print(a)
print(a.squeeze(0))
print(a.squeeze(1))

b = torch.Tensor(2, 3)
print(b)
print(b.squeeze(0))
print(b.squeeze(1))

c = torch.Tensor(3, 1)
print(c)
print(c.squeeze(0))
print(c.squeeze(1))

x = torch.tensor([1, 2, 3, 4])
print(x)
print(torch.unsqueeze(x, 0))
print(torch.unsqueeze(x, 1))

过程解析

定义张量 a,为 2 维,第一维度有 1 个元素,第二维度有 3 个元素。
输出:tensor([[2.6994e-30, 2.4164e-13, 1.8392e-13]])
通过 a.squeeze(0) 对第一维度进行降维,此时第一维度有 1 个元素,可降维,第一维度消失,第二维度自动变成第一维度有三个元素,与 a 相比,即消失了一层 “[]”。
输出:tensor([2.6994e-30, 2.4164e-13, 1.8392e-13])
通过 a.squeeze(1) 对第二维度进行降维,此时第一维度有 3 个元素,不可降维,则不做操作,输出与 a 相同。
输出:tensor([[2.6994e-30, 2.4164e-13, 1.8392e-13]])

定义张量 b,为 2 维,第一维度有 2 个元素,第二维度有 3 个元素。
第一、二维度均不可降维,因为三次输出相同。
输出:

tensor([[0., 0., 0.],
        [0., 0., 0.]])
tensor([[0., 0., 0.],
        [0., 0., 0.]])
tensor([[0., 0., 0.],
        [0., 0., 0.]])

定义张量 c,为 2 维,第一维度有 3 个元素,第二维度有 1 个元素。
输出:

tensor([[0.0000e+00],
        [       nan],
        [5.2781e-24]])

通过 c.squeeze(0) 对第一维度进行降维,此时第一维度有 3 个元素,不可降维,则不做操作,输出与 c 相同。
输出:

tensor([[0.0000e+00],
        [       nan],
        [5.2781e-24]])

通过 c.squeeze(1) 对第二维度进行降维,此时第二维度有 1 个元素,可降维,第二维度消失,第二维度数值自动进入第一维度中。
输出:
tensor([0.0000e+00, nan, 5.2781e-24])

定义张量 x,为 1 维,其中数值依次为 1, 2, 3, 4。
输出:tensor([1, 2, 3, 4])
通过 x.unsqueeze(0) 于第一维度位置增加一个维度,使原张量变成 2 维,维度变为 (1, 4)。与 x 相比,即增加了一层 “[]”。
输出:tensor([[1, 2, 3, 4]])
通过 x.unsqueeze(1) 于第二维度位置增加一个维度,使原张量变成 2 维,维度变为 (4, 1)。
输出:tensor([[1], [2], [3], [4]])

运行结果

tensor([[2.6994e-30, 2.4164e-13, 1.8392e-13]])
tensor([2.6994e-30, 2.4164e-13, 1.8392e-13])
tensor([[2.6994e-30, 2.4164e-13, 1.8392e-13]])
tensor([[0., 0., 0.],
        [0., 0., 0.]])
tensor([[0., 0., 0.],
        [0., 0., 0.]])
tensor([[0., 0., 0.],
        [0., 0., 0.]])
tensor([[0.0000e+00],
        [       nan],
        [5.2781e-24]])
tensor([[0.0000e+00],
        [       nan],
        [5.2781e-24]])
tensor([0.0000e+00,        nan, 5.2781e-24])
tensor([1, 2, 3, 4])
tensor([[1, 2, 3, 4]])
tensor([[1],
        [2],
        [3],
        [4]])
  • 3
    点赞
  • 25
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值