参考:
permute函数
python pytorch permute函数
[Python] 维度交换函数:transpose(m,n,r)和permute(m,n,r)
import torch
from torch import nn
import numpy as np
# Build a 2 x 3 x 4 integer array whose entries are 0..23 laid out in
# row-major (C) order -- the same values as writing the nested list out by
# hand. Its dtype is NumPy's default integer type, which is platform-dependent.
b = np.arange(24).reshape(2, 3, 4)
print(b)
[[[ 0 1 2 3]
[ 4 5 6 7]
[ 8 9 10 11]]
[[12 13 14 15]
[16 17 18 19]
[20 21 22 23]]]
# torch.tensor always copies its input, so `unpermuted` does not share memory
# with the NumPy array `b`; the tensor dtype is inferred from b's dtype
# (which is why the printed tensor may report dtype=torch.int32 on some
# platforms and omit the dtype on others).
unpermuted = torch.tensor(b)

# Reverse the order of all three dimensions. The result is a view of
# `unpermuted` (no copy) with permuted[i, j, k] == unpermuted[k, j, i];
# for b of shape (2, 3, 4) the result has shape (4, 3, 2). Since dim 1 keeps
# its position, this is equivalent to unpermuted.transpose(0, 2).
permuted = unpermuted.permute((2, 1, 0))

print(permuted)
tensor([[[ 0, 12],
[ 4, 16],
[ 8, 20]],
[[ 1, 13],
[ 5, 17],
[ 9, 21]],
[[ 2, 14],
[ 6, 18],
[10, 22]],
[[ 3, 15],
[ 7, 19],
[11, 23]]], dtype=torch.int32)