argmax(),argmin()是numpy模块中的函数:
直接举例说明:
import numpy as np
y = np.array([[0.1,0.98,0.69],[0.45,0.78,0.99]]) # 随便定义一个二维数组
y
输出:
array([[0.1 , 0.98, 0.69],
[0.45, 0.78, 0.99]])
axis = 1,表示对行操作,argmax是取一行中最大值对应的下标
a = y.argmax(axis = 1)
a
输出:
array([1, 2], dtype=int64)
axis = 0,表示对列操作,argmax是取一列中最大值对应的下标
a = y.argmax(axis = 1)
a
输出:
array([1, 0, 1], dtype=int64)
argmax()函数可以用在多分类任务的输出阶段,对每个样本输出的预测值直接取最大值对应的下标,很是方便
至于argmin()函数,是取相应的最下元素,这里就不多说!