测试代码如下:
import tensorflow as tf
y = tf.Variable(tf.random_normal(shape=[2, 3], mean=0, stddev=1), name='y')
y_ = tf.argmax(y, 1)
with tf.Session() as sess:
init_op = tf.global_variables_initializer()
sess.run(init_op)
print(sess.run(y))
print(sess.run(y_)) # y是一个2*3的矩阵,所以返回的是一个向量
打印结果:
[[2.1094291 0.15722927 1.7590628 ] # 第一行最大元素2.1094291,索引为0
[0.42864147 0.9278875 0.8963362 ]] # 第二行最大元素0.9278875,索引为1
[0 1] # 所以结果为[0, 1]
可以看出,tf.argmax的作用是返回每一行(如果axis=1的话,列类似)最大元素的索引。