tf.argmax(input, axis=None, name=None, dimension=None)
返回指定维度上,最大值所在位置的下标(索引)
axis=0表示沿着这一列,返回最大值所在的位置
axis=1表示沿着这一行,返回最大值所在的位置
例子:
axis=0
a = tf.constant([1.,2.,3.,4.,5.,])
with tf.Session() as sess:
result=sess.run(tf.argmax(a, 0))
print(result) # 4
b = tf.constant([[0,1,2],[2,1,0],[1,2,0]])
with tf.Session() as sess:
result=sess.run(tf.argmax(b, 0))
print(result) #[1 2 0]
axis=1
b = tf.constant([[0,1,2],[2,1,0],[1,2,0]])
with tf.Session() as sess:
result=sess.run(tf.argmax(b, 1))
print(result) #[2 0 1]