TensorFlow tf.argmax()函数
tf.argmax(input, axis=None, name=None, dimension=None)
对矩阵按行或列计算最大值
四个参数:
1.input:输入值
2.axis:可选值0表示按列,1表示按行求最大值
3.name
4.默认使用axis即可
重点说说axis参数的作用
举例说明
test = np.array([[1, 2, 3], [2, 3, 4], [5, 4, 3], [8, 7, 2]])
np.argmax(test, 0) #输出:array([3, 3, 1]
np.argmax(test, 1) #输出:array([2, 2, 0, 0]
解释:
# axis参数为0时:
test[0] = array([1, 2, 3])
test[1] = array([2, 3, 4])
test[2] = array([5, 4, 3])
test[3] = array([8, 7, 2])
# output : [3, 3, 1]
此时输出的是每一列最大值所在的数组下标。输出的数组元素数量是原矩阵的列数
# axis参数为1时:
test[0] = array([1, 2, 3]) #2
test[1] = array([2, 3, 4]) #2
test[2] = array([5, 4, 3]) #0
test[3] = array([8, 7, 2]) #0
# output : [2, 2, 0, 0]
此时输出的每一个数组中最大值所在的列号。输出的数组元素个数是原数组的数量,即原矩阵行数。
通过比较,我们可以看到,axis两个参数的区别是:0是每个数组对应位置之间的比较,而1则是数组内部元素之间的比较。