在numpy中,有很多的函数都涉及到axis,很多函数根据axis的取值不同,得到的结果也完全不同。这里通过详细的例子来学习下,axis到底是什么,它在numpy中的作用到底如何。
一、函数理解
首先argmax() 这个函数的作用是算出数组中最大值的下标。
举个例子:
a = [3, 1, 2, 4, 6, 1]
maxindex = 0
i = 0
for tmp in a:
if tmp > a[maxindex]:
maxindex = i
i += 1
print(maxindex)
二、参数理解
1.一维数组
import numpy as np
a = np.array([3, 1, 2, 4, 6, 1])
print(np.argmax(a))
当没有指定axis的时候,默认是0.所以最后输出的是4(也就是表示第四维值最大)
2.二维数组
import numpy as np
a = np.array([[1, 5, 4, 2],
[9, 6, 2, 8],
[3, 7, 9, 1]])
print(np.argmax(a, axis=0))
最后输出的是[1 2 2 1]
其中np.argmax(a, axis=0)的含义是a[0][j],a[1][j],a[2][j]中最大值的索引。
首先比较是a[0][0],a[1][0],a[2][0]可以得出最大值得下标为a[1][1]
,所以输出数组的第一个值为1.
然后比较的是a[0][0],a[1][1],a[2][2],可以得出最大值得下标为a[1][2],所以输出数组的第一个值为2. 以此类推,可以得出 最后输出为[1 2 2 1]
import numpy as np
a = np.array([[1, 5, 4, 2],
[9, 6, 2, 8],
[3, 7, 9, 1]])
print(np.argmax(a, axis=1))
其中np.argmax(a, axis=0)的含义是a[i][0],a[i][1],a[i][2],a[i][3]中最大值的索引。
首先比较是a[0][0],a[0][1],a[0][2],a[0][3],可以得出最大值得下标为a[0][1]
,所以输出数组的第一个值为1.
然后比较的是a[0][0],a[1][1],a[2][2],a[3][3],,可以得出最大值得下标为a[1][2],所以输出数组的第二个值为0. 以此类推,可以得出 最后输出为[1 0 2]
3.三维数组
import numpy as np
a = np.array([
[
[1, 5, 5, 2],
[9, -6, 2, 8],
[-3, 7, -9, 1]
],
[
[-1, 5, -5, 2],
[9, 6, 2, 8],
[3, 7, 9, 1]
]
])
print(np.argmax(a, axis=0))
np.argmax(a, axis=0)的含义是a[0][j][k],a[1][j][k] (j=0,1,2,k=0,1,2,3)中最大值的索引。从a[0][j][k]开始,a[0][j][k]对应的索引为((0,0,0,0),(0,0,0,0),(0,0,0,0)),拿a[0][j][k]和a[1][j][k]对应项作比较6大于-6,3大于-3,9大于-9,所以更新这几个位置的索引,将((0,0,0,0),(0,0,0,0),(0,0,0,0))更新为((0,0,0,0),(0,1,0,0),(1,0,1,0)).。
补充:
在a为二维数组的时候,axis=-1与axis=1结果一致
当a为三维数组的时候,axis=-1与axis=2结果一致,
可以理解为axis=-1结果与axis=(维度-1)一致