tf.argmax(input, axis=None, name=None, dimension=None)
Returns the index with the largest value across axis of a tensor.
input is a Tensor and axis describes which axis of the input Tensor to reduce across. For vectors, use axis = 0.
For your specific case let's use two arrays and demonstrate this
tf.argmax用来返回矩阵中每一行或每一列的最大值的序号;当axis=0时,返回每一列的最大值索引;当axis=1时返回每一行的最大值索引。
例子:
import tensorflow as tf
import numpy as np
A = [[1,3,4], [2,4,22],[10,3,11]]
with tf.Session() as sess:
print(sess.run(tf.argmax(A, 0)))
print("========================")
print(sess.run(tf.argmax(A, 1)))
输出:
[2 1 1]
========================
[2 2 2]
解释:
当axis=0时,输出每一列的最大值。
第一列10最大,序号为2
第二列4最大,序号为1
第三列22最大,序号为1
当axis=1时,输出每行的最大值。
第一行4最大,序号为2
第二行22最大,序号为2
第三行11最大,序号为2