tf.argmax()函数介绍和示例
tf.argmax(input_tensor, axis=0)
释义:返回指定维度最大值的索引
-
input_tensor,输入张量
-
axis,指定维度,默认为0。二维情况下,若为0,则返回列数据最大值索引;若为1,则返回行数据最大值索引
示例:
import tensorflow as tf
X = tf.constant([[1, 2, 3],
[2, 3, 4],
[5, 4, 3],
[8, 7, 2]], dtype=tf.float32)
row_max_index = tf.argmax(X) # 默认 0 维度,列最大值索引
col_max_index = tf.argmax(X, axis=1) # axis=1 或 1,行最大值索引
with tf.Session() as sess:
print('列最大值索引:', sess.run(row_max_index))
print('行最大值索引:', sess.run(col_max_index))
列最大值索引: [3 3 1]
行最大值索引: [2 2 0 0]