P3 PyTorch 维度变换

前言

参考

课时21 维度变换-2_哔哩哔哩_bilibili

目录:

  1.      view
  2.     unsqueeze
  3.      squeeze
  4.     Expand
  5.      repeat
  6.     转置
  7.     contiguous
  8.     Permute
  9.     例子

    

     


一 view 

      作用:

            重新调整Tensor的形状,通过shape,或size属性可以看出来

     1.1 普通的用法

    import torch 

    def test():
    
          img = torch.rand(4,1,28,28)
 
           a1= img.view(4,28*28)
           print("\n a1 ",a1.shape,a1.size())
    
     if __name__ == "__main__":
             test()

    输出:

 a1  torch.Size([4, 784]) torch.Size([4, 784])

   1.2参数-1 (自动调整该维度 size)

    

import torch 

def test():
    
    a = torch.arange(0, 16, 1)
 
    a1= a.view(-1,16)
    a2 = a.view(16,-1)
    print("\n a1 ",a2.shape,a1.shape)
    print(a1.size(0),a2.size(0))
if __name__ == "__main__":
    test()

输出

    a1    torch.Size([16, 1])         torch.Size([1, 16])
   1        16


 二  unsqueeze

      在指定的维度增加一个维度

      2.1  正数

         指定的位置前面插入一个维度

def test():
    
    a = torch.rand(4,5,6)
    
    a2 =a.unsqueeze(2).shape
    
    print("\n a2.shape: ",a2,a.ndim)
    a3 = a.unsqueeze(3).shape
    print("\n a3: ",a3)
 
 
if __name__ == "__main__":
    test()

       输出


     a2.shape  torch.Size([4, 5, 1, 6]) 3

     a3  torch.Size([4, 5, 6, 1])

  2.2 负数

       在指定的维度之后插入一个维度

       a = torch.rand(4,5,6)

      a2 =a.unsqueeze(-3).shape

      在指定的维度之后插入一个维度

     
      a2.shape:  torch.Size([4, 1, 5, 6]) 3 

例1: 在指定的维度之后插入

 例2 :   在指定的维度之前插入


三  squeeze

      维度挤压,如果指定的维度为1,则删除该维度,其它则保持不变

    a = torch.rand(1,3,1,1)
    
    b = a.squeeze().shape
    print("\n  默认 ",b)
    
   
    

    #索引为正数
    #Positive and negative number
    p_0 = a.squeeze(0).shape  
    print("\n 挤压0",p_0)
    
    p_1 = a.squeeze(1).shape
    print("\n 挤压1",p_1)
    
    #索引为负数
    n_0 = a.squeeze(-1).shape  
    print("\n 挤压-1 ",n_0)
    
    n_1 = a.squeeze(-4).shape
    print("\n 挤压-4 ",p_1)


四 Expand

   正数作用和reshape 一样,对应的维度上面调整到指定的大小

  负数 -1:

      表示该维度保持不变


五 repeat

    在指定的维度上面复制几次

   


六 矩阵转置

     下面这张方法值只适用于2D的矩阵。

    a = torch.rand(2,3)
    d= a.T.shape
    
    print("\n b ",d)


七  contiguous

   torch.contiguous()方法首先拷贝了一份张量在内存中的地址,然后将地址按照形状改变后的张量的语义进行排列。

torch.contiguous()方法语义上是“连续的”,经常与torch.permute()、torch.transpose()、torch.view()方法一起使用,要理解这样使用的缘由,得从pytorch多维数组的底层存储开始说起:

touch.view()方法对张量改变“形状”其实并没有改变张量在内存中真正的形状,可以理解为:

view方法没有拷贝新的张量,没有开辟新内存,与原张量共享内存;
view方法只是重新定义了访问张量的规则,使得取出的张量按照我们希望的形状展现。

   7.1 pytorch与numpy在存储MxN的数组时,均是按照行优先将数组拉伸至一维存储

        

   a = torch.tensor([[1,2,3],
                     [4,5,6]])
    
    print(a,a.shape)

   在内存中的样子:相当于做个flatten

  

[1, 2, 3, 4, 5, 6]

7.2 当我们使用torch.transpose()方法或者torch.permute()方法对张量翻转后,改变了张量内存的形状

   

    a = torch.tensor([[1,2,3],
                     [4,5,6]])
    
    print("\n a ",a.shape)
    
    a2 =a.transpose(0,1)

   此刻如果再想通过view 方式访问就会出错


原因是:改变了形状的a2语义上是2行3列的,在内存中还是跟a一样,没有改变,导致如果按照语义的形状进行view拉伸,数字不连续,此时torch.contiguous()方法就派上用场了
 

    a = torch.tensor([[1,2,3],
                  [4,5,6]])
    
    
    b = a.transpose(0,1).contiguous().view(2,3)
 
    print(b)
   
   out:
   tensor([[1, 4, 2],
        [5, 3, 6]])

八 Permute

      transpose 改变维度只能两两交换,有的时候需要多次交换比较繁琐

 比如[B,C,H,W]  需要更变成[B,H,W,C]

则 [B,C,H,W]--->[B,W,H,C]--->[B,H,W,C] 两次变更才能得到

   通过permute 操作一次就可以完成

  a = torch.rand(2,3,4,5)
    
  b = a.permute(0,2,3,1).shape
 torch.Size([2, 4, 5, 3])
    

   如下为shape的变换情况

 


 九  例子

     8.1 view

把原先tensor中的数据按照行优先的顺序排成一个一维的数据(这里应该是因为要求地址是连续存储的),然后按照参数组合成其他维度的tensor

 8.2 view+tranpose+continues

 a = torch.rand(5,3,16,16)
    
 a1 = a.transpose(1,3).view(5,3*16*16).view(5,3,16,16)

   5张图片,图片channel =3 ,height=16, width =16

通过transpose改变了meomery中的存储结构,不再是那种按行顺序结构了,再去访问的时候

就会出错

  解决方法:

    a1 = a.transpose(1,3).contiguous().view(5,3*16*16).view(5,3,16,16)
    
    eq = torch.all(torch.eq(a,a1))

    问题:

     不报错了,但是因为transpose改变了数据Memory中的顺序,再访问依然按照

 [B,W,H,C] 去获取[B,C,H,W]数据.

    解决方案:

    [B,C,H,W]->[B,W,H,C]->[B,C,H,W]

 a1 = a.transpose(1,3).contiguous().view(5,3*16*16).view(5,16,16,3).transpose(1,3)
    
    eq = torch.all(torch.eq(a,a1))
    
    print(eq)
    

    tensor(True)

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值