Pytorch中的permute函数和transpose,contiguous,view函数的关联

一、前言

在进行深度学习的过程中,经常遇到permute函数,transpose函数,view函数,contiguous函数等,他们起什么作用,之间又有什么联系呢?

二、主要内容

2.1、permute函数和transpose函数

Tensor.permute(a,b,c,d, …):可以对任意高维矩阵进行转置。例子见下:

In[1]: torch.randn(2,3,4,5).permute(3,2,0,1).shape

Out[1]:torch.Size([5, 4, 2, 3])

torch.transpose(Tensor, a,b):只能操作2D矩阵的转置,这是相比于permute的一个不同点;此外,由格式我们可以看出,transpose函数比permute函数多了种调用方式,即torch.transpose(Tensor, a,b)。但是,transpose函数可以通过多次变换达到permute函数的效果。具体见下:

#两种调用方式:
In[1]: t1 = torch.randint(1,10,(2,3,4,5))
	   shape1 = torch.transpose(t1,1,0).shape
	   shape2 = t1.transpose(1,0).shape
	   shape1,shape2
	   
Out[1]:(torch.Size([3, 2, 4, 5]), torch.Size([3, 2, 4, 5]))

#类似permute的效果
In[2]: shape3 = t1.transpose(3,0).transpose(2,1).transpose(3,2).shape
	   shape3
	   
Out[2]:torch.Size([5, 4, 2, 3])	   

2.2 permute函数和view函数

两个函数都是改变tensor的维度,但是区别在于__,具体如下:

#初始化
In[1]: a = torch.randint(1,10,(1,2,3))
 		a_size = a.size()
 		a,a_size
 		
Out[1]:(tensor([[[7, 4, 5],
         		 [9, 5, 6]]]), torch.Size([1, 2, 3]))
         		 
#permute
In[2]:per = a.permute(2,0,1)
	  per_size = per.size()
	   per,per_size
	   
Out[2]:(tensor([[[7, 9]],

        		[[4, 5]],

       			[[5, 6]]]), torch.Size([3, 1, 2]))

#view
In[3]: view = a.view(3,1,2)
	   view_size = view.size()
	   view,view_size

Out[3]:(tensor([[[7, 4]],

        		[[5, 9]],

        		[[5, 6]]]), torch.Size([3, 1, 2]))

#diff
#相信细心的小伙伴已经从两个的output看出来区别了
#具体原因就是在调用permute函数后,数据不再连续,即contiguous,可以继续看3.2

2.3、contiguous函数

contiguous函数起什么作用呢?
当我们在使用transpose或者permute函数之后,tensor数据将会变的不在连续,而此时,如果我们采用view函数等需要tensor数据联系的函数时,将会抛出以下错误:

Traceback (most recent call last):
File "<stdin>", line 1, in <module>
RuntimeError: invalid argument 2: view size is not compatible 
with input tensor's size and stride (at least one dimension spans 
across two contiguous subspaces). Call .contiguous() before 
.view(). at ..\aten\src\TH/generic/THTensor.cpp:203

如果这是使用了contiguous函数,将会解决此错误。

transposepermute操作虽然没有修改底层一维数组,但是新建了一份Tensor元信息,并在新的元信息中的 重新指定 stride。view 方法约定了不修改数组本身,只是使用新的形状查看数据,仅在底层数组上使用指定的形状进行变形。示例如下:

#初始化
In[1]:t = torch.arange(12).reshape(3,4)
	  t,t.stride()

Out[1]:tensor([[ 0,  1,  2,  3],
        	   [ 4,  5,  6,  7],
        	   [ 8,  9, 10, 11]]) ,  (4, 1)
#transpose
In[2]:t2 = t.transpose(0,1)
      t2,t2.stride()
  
Out[2]:tensor([[ 0,  4,  8],
        	   [ 1,  5,  9],
        	   [ 2,  6, 10],
        	   [ 3,  7, 11]]) ,  (1, 4)
#对比验证
In[3]:t.data_ptr() == t2.data_ptr() # 底层数据是同一个一维数组
Out[3]:True

In[4]:t.is_contiguous(),t2.is_contiguous() # t连续,t2不连续
Out[4]:(True, False)

#即t和t2引用同一份底层数据,如下:
#[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11]


#contiguous
In[5]:t3 = t2.contiguous()
      t3

Out[5]:tensor([[ 0,  4,  8],
        	   [ 1,  5,  9],
           	   [ 2,  6, 10],
        	   [ 3,  7, 11]])
        	   
In[6]:t3.data_ptr() == t2.data_ptr() # 底层数据不是同一个一维数组

Out[6]:False


#可以发现 t与t2 底层数据指针一致,t3 与 t2 底层数据指针不一致,说明确实重新开辟了内存空间。

三、结尾

上述只是简单介绍了下其功能和异同,具体原理没有深挖,对于想进一步了解contiguous函数的可以移步下方的参考。

ps:
参考—PyTorch中的contiguous

  • 11
    点赞
  • 27
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值