argmax()函数的axis参数特别迷,很奇怪,而且往往就算花了大功夫搞明白之后也还是似懂非懂
下面简要介绍下这个函数
官方介绍
argmax(a, axis=None, out=None)
Returns the indices of the maximum values along an axis.
Parameters
----------
a : array_like
Input array.
axis : int, optional
By default, the index is into the flattened array, otherwise
along the specified axis.
out : array, optional
If provided, the result will be inserted into this array. It should
be of the appropriate shape and dtype.
返回沿轴的最大值的索引。
axis:整数,可选
默认情况下,索引位于扁平数组中,否则沿指定轴执行。
自我理解
官方介绍等于没说,因为大多数新手还是看不明白。我是这样理解的:
我们把矩阵想象成空间中的一个长方体(当然4维及以上同理是超立方体)
以三维为例,矩阵mat = [[[3],[4]],[[2],[5]],[[1],[6]]],是一个3*2*1的立方体:
或者说,它的形状是3页、2行、1列
那么argmax(mat, 0)就表示,沿着第0个维度,也就是“3页”那个维度,查找最大值的索引:
这时候返回的应该是一个2行1列的矩阵,也这个矩阵的形状跟沿着这个维度排列的各个单元的形状相同:
当然代码方面也得到了验证:
>>> mat = np.array([[[3],[4]],[[2],[5]],[[1],[6]]])
>>> mat.shape
(3, 2, 1)
>>> np.argmax(mat, 0)
array([[0],
[2]], dtype=int64)
那argmax(mat, 1)呢?根据我们刚才的理解,应该是沿着“2行”这个维度进行查找,然后因为有3个可以作用的单元([[3], [4]]就是三维矩阵中的一个单元,[[2],[5]],[[1],[6]]都是,总共有3个)所以最后返回的应该是3个结果矩阵组成的一个大矩阵:
代码方面验证:
>>> np.argmax(mat, 1)
array([[1],
[1],
[1]], dtype=int64)
那么最后,如果是argmax(mat, 2)呢?代码方面的验证结果是:
np.argmax(a, 2)
array([[0, 0],
[0, 0],
[0, 0]], dtype=int64)
你能想明白吗?