Pytorch view和permute的用法
一. view的用法
首先把原先tensor中的数据按照行优先的顺序排成一个一维的数据(这里应该是因为要求地址是连续存储的),然后按照参数组合成其他维度的tensor。
比如
import torch
import numpy as np
a =np.array([[1,2,3],[4,5,6]])
a = torch.tensor(a)
print(a.size())
构造了一个2*3的tensor
torch.Size([2, 3])
tensor([[1, 2, 3],
[4, 5, 6]], dtype=torch.int32)
但是我们想把这个tensor如下变成3*2的张量,就可以使用view函数
print(a.view(3,2))
先按照行排列,然后才转换成想要的形状
tensor([[1, 2],
[3, 4],
[5, 6]], dtype=torch.int32)
在这里view有一个特殊的参数,那就是-1,view(-1)是什么意思呢?
我们先看一下输出结果,结果变成了一行数据,也就是说不管原来是什么维度的张量,经过view操作之后,行优先的顺序变成了一行数据
print(a.view(-1))
tensor([1, 2, 3, 4, 5, 6], dtype=torch.int32)
-1在这里的意思是让电脑帮我们计算,比如下面的例子,总长度是20,我们不想自己算20/5=4,就可以在不想算的位置放上-1,电脑就会自己计算对应的数字,这个在实际搭建网络的时候是很好用的。
import torch
a = torch.arange(0,20) #此时a的shape是(1,20)
a.view(4,5).shape #输出为(4,5)
a.view(-1,5).shape #输出为(4,5)
a.view(4,-1).shape #输出为(4,5)
二. permute的用法
permute(dims)
将tensor的维度换位。
参数:参数是一系列的整数,代表原来张量的维度。比如三维就有0,1,2这些dimension。
再比如图片img的size比如是(28,28,3)就可以利用img.permute(2,0,1)得到一个size为(3,28,28)的tensor。就是不改变每一行数据的基本内容,并没有进行重新排列。
import torch
import numpy as np
a =np.array([[1,2,3],[4,5,6]])
a = torch.tensor(a)
print(a.size()) #输出为(2,3)
print(a.permute(1,0)) #输出为(3,2)
tensor([[1, 4],
[2, 5],
[3, 6]], dtype=torch.int32)