numpy中数组有数轴axis的说法,在numpy中提供的一些方法中经常需要这个参数。
其实numpy中的axis和维度是一一对应关系。
比如:
arr1 = np.arange(24).reshape((2, 3, 4))
#其中axis=0对应第一个维度2,axis=1对应第二个维度3,axis=2对应第三个维度4
#然后再哪个数轴上进行运算时,就相当于在对应维度的变化方向上做运算。
三维数组维度为(块, 行, 列)
二维数组维度为(行, 列)
例子1:
arr1 = np.arange(24).reshape((2, 3, 4))
##当axis=0时
print("arr1:\n", arr1)
t1 = arr1.max(axis=0)
print("arr1.max(axis=0):\n", t1)
输出:
arr:
[[[ 0 1 2 3]
[ 4 5 6 7]
[ 8 9 10 11]]
[[12 13 14 15]
[16 17 18 19]
[20 21 22 23]]]
arr.max(axis=0):
[[12 13 14 15]
[16 17 18 19]
[20 21 22 23]]
分析:在axis=0(块)的变化方向上运算,得到去掉第一维度的数组为(3, 4)。然后第一块(3, 4)和第二块(3, 4)对应位置比较,最终得到数组t1,维度为(3, 4)
例子2:
##当axis=1时
print("arr1:\n", arr1)
t2 = arr1.max(axis=1)
print("arr1.max(axis=1):\n", t2)
输出:
arr:
[[[ 0 1 2 3]
[ 4 5 6 7]
[ 8 9 10 11]]
[[12 13 14 15]
[16 17 18 19]
[20 21 22 23]]]
arr.max(axis=1):
[[ 8 9 10 11]
[20 21 22 23]]
分析:在axis=1(行)的变化方向上运算,得到去掉第二维度的数组为(2, 4)。然后每块在第二维度的变化方向上计算最大值,得到最终数组t2,维度为(2, 4)
例子3:
##当axis=2时
print("arr1:\n", arr1)
t3 = arr.max(axis=2)
print("arr.max(axis=2):\n", t3)
输出:
arr:
[[[ 0 1 2 3]
[ 4 5 6 7]
[ 8 9 10 11]]
[[12 13 14 15]
[16 17 18 19]
[20 21 22 23]]]
arr.max(axis=2):
[[ 3 7 11]
[15 19 23]]
分析:在axis=3(列)的变化方向上运算,得到去掉第三维度的数组为(2, 3)。然后两快都在第三维度的变化方向上计算最大值,得到最终的数组t3,维度为(2, 3)
参考:大佬的博客
https://blog.csdn.net/qq_29573053/article/details/76998695