目录
参考 PyTorch 高维矩阵转置 Transpose 和 Permute ;这里只讲了 permute和Transpose ,
其它相关的函数可参考:
【1】一文掌握torch.squeeze() 和torch.unsqueeze()的用法
【2】pytorch中x = x.view(x.size(0), -1) 的理解
Tensor.permute(d0,d1,d2,d2)
注:但没有 torch.permute() 这个调用方式, 只能 Tensor.permute()。
我们可以把这个permute的过程这样理解;把tensor中的各维度上的值和编号对应,那么
对于tensor(0,1,2,3),permute(d0,d1,d2,d2)分别把 0和d0、1和d1、2和d2、3和d3的维度值交换,并且
d0,d1,d2,d3就是0,1,2,3的一个排列;
注:这里的维度值表示的是在该维度上有多少个数据;eg.对于二维的tensor(3,4)表示3行4列; 0表示行,1表示列。
import torch
image = torch.randn(1,16 ,256, 256) # torch.randn标准正态分布
image = image.permute(1,0,2,3)
print(image.shape)#torch.Size([16, 1, 256, 256])
torch.Transpose(Tensor, a,b)
连续使用transpose也可实现permute的效果。
transpose只能操作2D矩阵的转置。 0表示行,1表示列。
torch.transpose(Tensor, 1, 0)
# transpose 一次只能有两个参数,否则报错了会!!
e = a.transpose(2,0,1)
print(e)
print(e.size())
# TypeError: transpose() takes 2 positional arguments but 3 were given
# torch.rand 产生均匀分布的随机数 torch.rand(2,3,4,5).transpose(3,0).transpose(2,1).transpose(3,2).shape Out: torch.Size([5, 4, 2, 3])
torch.rand(2,3,4,5).transpose(1,0).transpose(2,1).transpose(3,1).shape # Out: torch.Size([3, 5, 2, 4])