numpy中的transpose()函数用于数组转置变换,对于二维的好理解,但是对于高维的就很容易出现混乱,有必要整理出比较直观的转换方法。
先来看numpy.transpose()的函数定义:
def transpose(a, axes=None):
"""
Permute the dimensions of an array.
Parameters
----------
a : array_like
Input array.
axes : list of ints, optional
By default, reverse the dimensions, otherwise permute the axes
according to the values given.
两个参数,一个输入数组,一个代表轴转换方式的列表。
为便于理解从二维数组的转置引申到高维数组,二维数据的转置例子很简单,行和列调换,只有一种情况。
import numpy as np
arr=np.arange(6).reshape(2,3)
arr_T=arr.transpose()
print(arr)
print(arr_T)
输出如下:
[[0 1 2]
[3 4 5]]
[[0 3]
[1 4]
[2 5]]
线性代数里面学过,若B是A的转置矩阵,则B与A中的元素满足:b(j,i)=a(i,j)。那么我们可以根据这个关系求一个矩阵的转置矩阵,即交换原矩阵元素的下标即可找到该元素在转置矩阵中的位置。
下面直接应用这个方法对三维矩阵进行转置。定义一个三维矩阵如下:
A=[[[ 0 1 2]
[ 3 4 5]]
[[ 6 7 8]
[ 9 10 11]]]
A中各元素的下标为:
0(0,0,0) 1(0,0,1) 2(0,0,2)
3(0,1,0) 4(0,1,1) 5(0,1,2)
6(1,0,0) 7(1,0,1) 8(1,0,2)
9(1,1,0) 10(1,1,1) 11(1,1,2)
numpy.transpose()函数会对原数组按给定的axis的排列顺序进行转换,三维数组原始的axis的排列顺序为(0,1,2),numpy.transpose()的axis的默认参数为原数组的axis排列的倒序,即(2,1,0)。按字面意思是将原数组的2轴和0轴进行调换,数组元素的角度看就是将第2个下标和第0各下标进行调换。
A中各元素的下标为:
0(0,0,0) 1(0,0,1) 2(0,0,2)
3(0,1,0) 4(0,1,1) 5(0,1,2)
6(1,0,0) 7(1,0,1) 8(1,0,2)
9(1,1,0) 10(1,1,1) 11(1,1,2)
将A数组中各元素的第2个下标和第0个下标互换位置得到新数组B中各元素的下标为:
0(0,0,0) 1(1,0,0) 2(2,0,0)
3(0,1,0) 4(1,1,0) 5(2,1,0)
6(0,0,1) 7(1,0,1) 8(2,0,1)
9(0,1,1) 10(1,1,1) 11(2,1,1)
根据新下标将元素写到对应位置即可得到B数组:
B=[[[ 0 6]
[ 3 9]]
[[ 1 7]
[ 4 10]]
[[ 2 8]
[ 5 11]]]
用程序验证一下:
>>> import numpy as np
>>> arr=np.arange(12).reshape(2,2,3)
>>> print(arr)
[[[ 0 1 2]
[ 3 4 5]]
[[ 6 7 8]
[ 9 10 11]]]
>>> print(arr.transpose())
[[[ 0 6]
[ 3 9]]
[[ 1 7]
[ 4 10]]
[[ 2 8]
[ 5 11]]]
>>>
我们再看给定axis=(1,2,0)的情况,这个转换顺序可分解为(0,1,2)->(2,1,0)->(1,2,0),动手移动原数组下标结果如下:
原数组各元素的下标为:
0(0,0,0) 1(0,0,1) 2(0,0,2)
3(0,1,0) 4(0,1,1) 5(0,1,2)
6(1,0,0) 7(1,0,1) 8(1,0,2)
9(1,1,0) 10(1,1,1) 11(1,1,2)
新数组各元素的下标为:
0(0,0,0) 1(0,1,0) 2(0,2,0)
3(1,0,0) 4(1,1,0) 5(1,2,0)
6(0,0,1) 7(0,1,1) 8(0,2,1)
9(1,0,1) 10(1,1,1) 11(1,2,1)
新数组=[[[ 0 6]
[ 1 7]
[ 2 8]]
[[ 3 9]
[ 4 10]
[ 5 11]]]
程序验证一下:
>>> import numpy as np
>>> arr=np.arange(12).reshape(2,2,3)
>>> print(arr)
[[[ 0 1 2]
[ 3 4 5]]
[[ 6 7 8]
[ 9 10 11]]]
>>> print(arr.transpose())
[[[ 0 6]
[ 3 9]]
[[ 1 7]
[ 4 10]]
[[ 2 8]
[ 5 11]]]
>>> print(arr.transpose((1,2,0)))
[[[ 0 6]
[ 1 7]
[ 2 8]]
[[ 3 9]
[ 4 10]
[ 5 11]]]
>>>
发现与我们手动移下标的结果无误!