PyTorch中的unsqueeze函数(自用)

前言

最近在学习swin_transformer的模型搭建,其中用到了广播机制,在理解广播机制的过程中发现自己对torch.unsqueeze()函数比较困惑,所以做了个小实验帮助自己理解。

问题阐述

我们都知道,torch.unsqueeze()函数的作用是拓展张量维度,那么在不同位置拓展之后,原数据是怎样排列的呢?下面进入实验部分。

实验

>>> import torch
>>>
>>> a = torch.Tensor([1,2,3,4,5,6,7,8,9])
>>> print(a)
tensor([1., 2., 3., 4., 5., 6., 7., 8., 9.])
>>> b = a.view(3,3)
>>> print(b)
tensor([[1., 2., 3.],
        [4., 5., 6.],
        [7., 8., 9.]])
>>> c = b.unsqueeze(1)
>>> print(c)
tensor([[[1., 2., 3.]],

        [[4., 5., 6.]],

        [[7., 8., 9.]]])
>>> print(c.size())
torch.Size([3, 1, 3])
>>> d = b.unsqueeze(2)
>>> print(d)
tensor([[[1.],
         [2.],
         [3.]],

        [[4.],
         [5.],
         [6.]],

        [[7.],
         [8.],
         [9.]]])
>>> print(d.size())
torch.Size([3, 3, 1])
>>> e = b.unsqueeze(0)
>>> print(e)
tensor([[[1., 2., 3.],
         [4., 5., 6.],
         [7., 8., 9.]]])
  1. 创建Tensor变量a,此时a.shape = ([9]),这里解释为有9个元素。将a变成3行3列的矩阵b,此时b.shape = ([3,3]),这里解释为有3个通道,每个通道中有3个元素。
  2. c = b.unsqueeze(1),此时c.shape = ([3,1,3]),这里解释为有3个通道,每个通道有1行,每行有3列。
  3. d = b.unsqueeze(2),此时d.shape = ([3,3,1]),这里解释为有3个通道,每个通道有3行,每行有1列。

值得注意的是,上述步骤2,3中元素排列的方式不一样,2中元素水平排列,3中元素竖直排列,因而使用广播机制时需要进行复制的数值也不一样。

总结&拓展

至此,实验结束,对torch.unsqueeze()函数的理解加深不少。下面放上广播机制相关博文供参考。

PyTorch | 广播机制(broadcast)_pytorch broadcast-CSDN博客

 

  • 5
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值