Pytorch基础 - 1. torch.squeeze() 和 unsqueeze()

tensor升维和降维是神经网络的基本操作,比如不同维feature融合等都需要改操作。常用的函数有torch.unsqueeze() 和 torch.unsqueeze()操作。

目录

1. tensor降维操作: torch.squeeze() 和 指定index 

2. tensor升维操作: torch.unsqueeze() 和 使用None

 3. torch.squeeze和torch.unsqueeze的另一种写法


1. tensor降维操作: torch.squeeze() 和 指定index 

(1) 使用torch.squeeze(input,dim),默认删除tensor中所有维度为1的维度,也可指定dim。torch.squeeze — PyTorch 1.13 documentation

import torch

if __name__ == '__main__':
    a = torch.randn((2, 1, 3, 1, 4))
    a1 = torch.squeeze(a)
    print(a1.shape)  # torch.Size([2, 3, 4])
    a2 = torch.squeeze(a, dim=1)
    print(a2.shape)  # torch.Size([2, 3, 1, 4])
    a3 = torch.squeeze(a, dim=3)
    print(a3.shape)  # torch.Size([2, 1, 3, 4])

(2) 也可使用index=0直接指定,使用torch.equal比较两者相等。

if __name__ == '__main__':
    a = torch.randn((2, 1, 3, 1, 4))
    a1 = torch.squeeze(a)
    print(a1.shape)  # torch.Size([2, 3, 4])

    a2 = a[:, 0, :, 0]
    print(a2.shape)  # torch.Size([2, 3, 4])

    print(torch.equal(a1, a2))  # True

2. tensor升维操作: torch.unsqueeze() 和 使用None

(1) torch.unsqueeze(input, dim) ,对指定的dim,执行升维操作,具体可参考官方文档以及如下示例。torch.unsqueeze — PyTorch 1.13 documentation

import torch

if __name__ == '__main__':
    a = torch.randn((2, 3, 4))
    a1 = torch.unsqueeze(a, dim=1)
    print(a1.shape)  # torch.Size([2, 1, 3, 4])
    a2 = torch.unsqueeze(a, dim=2)
    print(a2.shape)  # torch.Size([2, 3, 1, 4])

(2) 简单用法:使用None,使用None来增加新维度

import torch

if __name__ == '__main__':
    a = torch.randn((2, 3, 4))
    a1 = a[:, None, ...]
    print(a1.shape)  # torch.Size([2, 1, 3, 4])
    a2 = a[..., None, :]
    print(a2.shape)  # torch.Size([2, 3, 1, 4])

注意:a1中None后面的三个点可以省略,如下

import torch

if __name__ == '__main__':
    a = torch.randn((2, 3, 4))

    a1_old = a[:, None, ...]
    print(a1_old .shape)  # torch.Size([2, 1, 3, 4])
    a1_new = a[:, None]
    print(a1_new .shape)  # torch.Size([2, 1, 3, 4])

    print(torch.equal(a1_old, a1_new))  # True

 3. torch.squeeze和torch.unsqueeze的另一种写法

一般情况下使用torch.squeeze(x, dim=?)来进行降维,当然还可以直接使用 x.squeeze(dim=?)。

import torch

if __name__ == '__main__':
    a = torch.randn((2, 3, 4))
    a1 = torch.unsqueeze(a, dim=0)
    print(a1.shape)  # torch.Size([1, 2, 3, 4])
    # 另一种写法
    a2 = a.unsqueeze(dim=0)
    print(a2.shape)  # torch.Size([1, 2, 3, 4])
  • 6
    点赞
  • 13
    收藏
    觉得还不错? 一键收藏
  • 4
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值