Numpy中numpy.rollaxis函数的理解
在学习Numpy的过程中卡在了这个古怪的函数好一阵子不明其意,研究一番后终于有一些醒悟,把理解贴出来以后备用。
要想理解这个函数首先要理解在Numpy中是如何输出一个多维数组的。Numpy在这里把维数说成“轴”,从0轴到n轴递增。其实这就跟直角坐标系中的xyz轴是一个道理,只不过它只有3个轴罢了。
现在看看numpy.rollaxis函数的官方解释:
numpy.rollaxis(a, axis,start=0)
Roll the specified axis backwards, until it lies in a given positions.
Parameters:
a [ndarray] Input array.
axis [int] The axis to roll backwards.The positions of the other axes donot change relative to one another.
start [int, optional] The axis is rolled until it lies before this position.The default,0, results in a “complete” roll.
其中有两个比较重要的点:
- 把特定的轴向后滚动,直到它到达指定的位置;
- 滚动后,其他轴与其他轴之间的相对位置不变。
知道这些后,我们可以自己来实验一下:
import numpy as np
# 创建了三维的 ndarray
a = np.arange(8).reshape(2, 2, 2)
print(a)