函数说明:
tf.argmax( input, axis=None, name=None, dimension=None, output_type=tf.int64 )
参数:
- input: 输入矩阵
- axis: 默认为None,0表示的是按列比较返回最大值的索引,1表示按行比较返回最大值的索引
- name: 默认为None
- dimension: 默认为None
- output_type: 默认类型为int64
用途:
返回最大的那个数值所在的下标(第一个参数是矩阵,第二个参数是0或者1。0表示的是按列比较返回最大值的索引,1表示按行比较返回最大值的索引)。
代码示例:
import tensorflow as tf Vector = [1,1,2,5,3] #定义一个向量 X = [[1,3,2],[2,5,8],[7,5,9]] #定义一个矩阵 with tf.Session() as sess: a = tf.argmax(Vector, 0) b = tf.argmax(X, 0) c = tf.argmax(X, 1) print(sess.run(a)) print(sess.run(b)) print(sess.run(c))
运行结果:
3 [2 1 2] [1 2 2]