1 先看看官方中英文doc:
torch.Tensor.permute (Python method, in torch.Tensor)
1.1 permute(dims)
将tensor的维度换位。
参数: - dims (int …*) - 换位顺序
例:
>>> x = torch.randn(2, 3, 5)
>>> x.size()
torch.Size([2, 3, 5])
>>> x.permute(2, 0, 1).size()
torch.Size([5, 2, 3])
permute函数功能还是比较简单的,下面主要介绍几个细节点:
2.1 transpose与permute的异同
Tensor.permute(a,b,c,d, …):permute函数可以对任意高维矩阵进行转置,但没有 torch.permute() 这个调用方式, 只能 Tensor.permute():
>>> torch.randn(2,3,4,5).permute(3,2,0,1).shape
torch.Size([5, 4, 2, 3])
torch.transpose(Tensor, a,b):transpose只能操作两个维度进行转置,有两种调用方式;
另:连续使用transpose也可实现permute的效果:
>>> torch.randn(2,3,4,5).transpose(3,0).transpose(2,1).transpose(3,2).shape
torch.Size([5, 4, 2, 3])
>>> torch.randn(2,3,4,5).transpose(1,0).transpose(2,1).transpose(3,1).shape
torch.Size([3, 5, 2, 4])
从以上操作中可知,permute相当于可以同时操作于tensor的若干维度,transpose只能同时作用于tensor的两个维度,超过两个维度会报错
torch.randn(2,3,4,5).transpose(3,0,1,2)
Traceback (most recent call last):
File "D:\anaconda\lib\site-packages\IPython\core\interactiveshell.py", line 2963, in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
File "<ipython-input-9-324e82f49bd8>", line 1, in <module>
torch.randn(2,3,4,5).transpose(3,0,1,2)
TypeError: transpose() takes 2 positional arguments but 4 were given