背景引入:
在神经网络的数据处理部分,常要用到numpy中的transpose()函数,对二维矩阵的转置大家都明白,但是对高维数组array和矩阵的transpose还是值得记录一下的。
代码示例:
>>>import numpy as np
>>>arr1=np.arange(16).reshape(2,2,4)
>>> arr1
array([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7]],
[[ 8, 9, 10, 11],
[12, 13, 14, 15]]])
>>> arr1.shape#以元组形式输出
(2, 2, 4)
我们这样来理解arr1的形状,它是一个三维的数组:2*2*4,即表示由两个2*4的矩阵构成。就像RGB的3通道,这里是两个通道,每个通道是由2行4列的小矩阵构成。那么arr1所对应的轴或者说维数就是0,1,2,即(2,2,4)对应(0,1,2)。
下面来看3个transpose变换。
arr1.transpose((1,0,2))#将(0,1,2)变为(1,0,2),形状仍为(2,2,4)
array([[[ 0, 1, 2, 3],
[ 8, 9, 10, 11]],
[[ 4, 5, 6, 7],
[12, 13, 14, 15]]])
这个变换是说将第一轴和第二轴互换,第三轴保持不变。具体到元素应该怎么办呢?比如元素0的索引是(0,0,0),将一二轴互换后仍是(0,0,0),所以不动,同理索引为(0,0,x)和(1,1,x)的元素都不用动。也就是[0,1,2,3]和[12,13,14,15]都不用动。但对于元素4,索引为(0,1,0),将一二轴互换为(1,0,0),同理元素8,索引为(1,0,0)变为(0,1,0),即4,和8互换。
>>> arr1.transpose((1,2,0))#将(0,1,2)变为(1,2,0),此时形状变为(2,4,2)
array([[[ 0, 8],
[ 1, 9],
[ 2, 10],
[ 3, 11]],
[[ 4, 12],
[ 5, 13],
[ 6, 14],
[ 7, 15]]])
仿照上面的分析,元素1原来的索引为(0,0,1),由于轴(0,1,2)变为(1,2,0),故(0,0,1)变为(0,1,0).同理可以分析出其他元素的变动。
再看最后一个:
>>> arr1.transpose((2,0,1))#(0,1,2)变为(2,0,1),形状由(2,2,4)变为(4,2,2)
array([[[ 0, 4],
[ 8, 12]],
[[ 1, 5],
[ 9, 13]],
[[ 2, 6],
[10, 14]],
[[ 3, 7],
[11, 15]]])
其实只要理解了轴动,数字跟着动就行了。
参考文献: