捋清pytorch的transpose、permute、view、reshape、contiguous

transpose和permute

都是进行转置操作,但是有些许区别,permute可以完全替代transpose,transpose不能替代permute。

transpose的基本操作

接收两个维度dim1和dim2,将dim1和dim2调换:

In [82]: a=torch.Tensor([[1,2,3],[4,5,6]])
In [85]: a                                                                                                                  
Out[85]: 
tensor([[1, 2, 3],
        [4, 5, 6]], dtype=torch.int32)

In [87]: a.transpose(0,1)                                                                                                   
Out[87]: 
tensor([[1, 4],
        [2, 5],
        [3, 6]], dtype=torch.int32)

# dim1和dim2的顺序不影响结果。
In [88]: a.transpose(1,0)                                                                                                   
Out[88]: 
tensor([[1, 4],
        [2, 5],
        [3, 6]], dtype=torch.int32)

# 可以用torch.transpose(tensor, dim1, dim2)调用,
# 也可以直接用tensor.transpose(dim1, dim2)调用
In [89]: torch.transpose(a, 1 , 0)                                                                                          
Out[89]: 
tensor([[1, 4],
        [2, 5],
        [3, 6]], dtype=torch.int32)

permute的基本操作

重组tensor维度,支持高维操作,tensor.permute(dim0, dim1, … dimn),表示原本的dim0放在第0维度,dim1放在第1维度,…, dimn放在第n维度,必须将所有维度写上。可以想象如果tensor.permute(0,1,2,3,…,n)相当于没有操作。

In [90]: a.permute(1,0)                                                                                                     
Out[90]: 
tensor([[1, 4],
        [2, 5],
        [3, 6]], dtype=torch.int32)

In [91]: a.permute(0,1)                                                                                                     
Out[91]: 
tensor([[1, 2, 3],
        [4, 5, 6]], dtype=torch.int32)

# 不支持用torch.permute调用
In [92]: torch.permute                                                                                                      
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-92-be81eddf2bd2> in <module>
----> 1 torch.permute

AttributeError: module 'torch' has no attribute 'permute'

差别

transpose一次只能调换两个维度,而permute可以调换多个维度,在二维情况下,两者可以互相替换,但是在三维及以上时,permute就比transpose强大得多。比如有个三维tensor,想要将第零维放到第一维,第一维放到第二维,第二维放回第零维,用permute一句就实现了,用transpose要转换两次。

In [93]: b = torch.rand(2,3,4)                                                                                              

In [94]: b.permute(2,0,1)                                                                                                   
Out[94]: 
tensor([[[0.3574, 0.5199, 0.0402],
         [0.3090, 0.5155, 0.5344]],

        [[0.8814, 0.7096, 0.1845],
         [0.6959, 0.0752, 0.5228]],

        [[0.1947, 0.5147, 0.4137],
         [0.4682, 0.4123, 0.1977]],

        [[0.9947, 0.9643, 0.7728],
         [0.7088, 0.6423, 0.3669]]])

In [95]: b.transpose(0,1).transpose(0,2)                                                                                    
Out[95]: 
tensor([[[0.3574, 0.5199, 0.0402],
         [0.3090, 0.5155, 0.5344]],

        [[0.8814, 0.7096, 0.1845],
         [0.6959, 0.0752, 0.5228]],

        [[0.1947, 0.5147, 0.4137],
         [0.4682, 0.4123, 0.1977]],

        [[0.9947, 0.9643, 0.7728],
         [0.7088, 0.6423, 0.3669]]])

contiguous

contiguous意为连续,contiguous()方法将tensor转为连续的并返回一个新的tensor,如果本身就是连续的,那么不进行任何操作,那么何为连续的tensor?

pytorch在新建任何尺寸tensor时,总会以一维数组的形式去存储,同时建立配套的元信息,保存了tensor的形状,在访问tensor时,将多维索引转化成一维数组相对于数组起始位置的偏移,即可找到对应的数据。那么这个一维数组是什么样的?它是将我们建立的tensor按照行优先的原则进行展开,比如一个torch.Tensor([[1,2,3],[4,5,6]])它是以[1,2,3,4,5,6]的形式保存在内存中。

一个新建的tensor,一定是连续的,行相邻的两个元素,在内存中也一定相邻。但是如果经过transpose或者permute等操作,行相邻的两个元素,在内存上不相邻了,此为不连续。注意transpose和permute并不会改变tensor的底层的一维数组,只是会改变元信息。

那么如何让经过了transpose或者permute的tensor重新变成连续?就是调用contiguous方法,它会重新生成一个tensor,新tensor底层的一维数组和原来的不一样,它是将当前tensor按行展开进行保存的。

举个例子,下面的is_contiguous()是判断tensor是否连续,data_ptr是返回tensor的数据指针

In [100]: a=torch.Tensor([[1,2,3],[4,5,6]])                                                                                 

In [101]: a                                                                                                                 
Out[101]: 
tensor([[1., 2., 3.],
        [4., 5., 6.]])

In [102]: a.is_contiguous()                                                                                                 
Out[102]: True

# 注意:这里flatten()只是演示a在底层的一维数组的样子
# flatten的结果可不一定是底层一维数组的样子,只有在tensor连续时才刚好一样。
In [103]: a.flatten()                                                                                                       
Out[103]: tensor([1., 2., 3., 4., 5., 6.])

# 转置后不连续了
In [104]: a.transpose(0,1).is_contiguous()                                                                                  
Out[104]: False

# contiguous一下又变连续了
In [107]: a.transpose(0,1).contiguous().is_contiguous()                                                                     
Out[107]: True

# transpose后,还是用的同一份底层数组,
# 但是contiguous是换了一份底层数组,可以说是完全不一样的tensor
In [108]: a.data_ptr()                                                                                                      
Out[108]: 66883392

In [109]: a.transpose(0,1).data_ptr()                                                                                       
Out[109]: 66883392

In [110]: a.transpose(0,1).contiguous().data_ptr()                                                                          
Out[110]: 66774976

view和reshape

view和reshape都能改变tensor尺寸,但是有两点区别:

  1. view是在原来数组上更改,不会开辟新数组,reshape会开辟新数组 (经评论区指正,如果tensor是连续的,那么reshape和view一样也不会开辟新tensor,而是直接修改)
  2. view要求tensor是连续的,不连续会报下方的错,reshape不要求。所以如果tensor曾经进行过转置比如transpose或者permute,则一定要先经过contiguous()转为连续数组,再进行view操作。当然上面也讲过,contiguous会返回一个新数组,所以如果tensor曾转置过,使用view和reshape都会得到一个新tensor。
RuntimeError: view size is not compatible with input tensor's size and stride (at 
least one dimension spans across two contiguous subspaces). Use .reshape(...) 
instead.

报错信息也让用reshape了,不用reshape的话先用contiguous再view也可以,暂时不清楚这两者有什么区别。

参考文献

  1. PyTorch的permute和reshape/view的区别:https://blog.csdn.net/xpy870663266/article/details/101616286
  2. PyTorch中的contiguous:https://zhuanlan.zhihu.com/p/64551412
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值