tf.argmax( )
函数作用:
计算矩阵每行或每列最大值的索引参数:
tf.argmax(input, axis = None, name = None, dimension = None)
input:输入tensor
axis:0表示按列,1表示按行
name:自定义输出tensor的名称
dimension:和axis功能一样,默认axis取值优先。返回:
行或列最大值的索引,组成的tensor例子:
test = tf.constant([[1, 2, 3], [2, 3, 4], [3, 4, 5], [4, 5, 6]])
test_0 = tf.argmax(test, 0) # 按列比较,返回每列最大元素的索引
test_1 = tf.argmax(test, 1) # 按行比较,返回每行最大元素的索引
with tf.Session() as sess:
print(sess.run(test))
print()
print(sess.run(test_0)) # 输出[3 3 3]
print()
print(sess.run(test_1)) # 输出[2 2 2 2]