在CNN模型中经常会看到transpose算子,transpose就是对多维数组进行转置操作。
下面我们用numpy里面的transpose函数进行理解学习。
1、二维数组
import numpy as np
a=np.arange(6).reshape(2,3)
print("before transpose:")
print(a.shape)
print(a)
a = a.transpose(1,0)
print("after transpose:")
print(a.shape)
print(a)
可以看到2x3的二维数组变成了3X2的二维数组。
2、三维数组
import numpy as np
a=np.arange(16).reshape(2,2,4)
print("before transpose:")
print(a.shape)
print(a)
a = a.transpose(2,1,0)
print("after transpose:")
print(a.shape)
print(a)
可以看到2x2x4的三维数组变成了4x2x2的三维数组
那是怎么转的呢,这里借助一个坐标系来理解
原始2x2x4表示成如下三维坐标系
a.transpose(2,1,0) 即吧0轴和2轴换一下,1轴不变
这样重新读取就是
0 8 4 12 1 9 5 13 2 10 6 14 3 11 7 15