tf.argmax()函数格式:
tf.argmax(input, axis=None, name=None, dimension=None)
- input:输入值
- axis:0表示按列计算每列最大数的下标,1表示按行计算每行最大数的下标
- name:名称
- dimension:和axis功能一样,默认axis取值优先
import tensorflow as tf a = tf.get_variable(name='a', shape=[3, 4], dtype=tf.float32, initializer=tf.random_uniform_initializer(minval=-1, maxval=1)) b = tf.argmax(input=a, axis=0) c = tf.argmax(input=a, dimension=1) # 此处用dimesion或用axis是一样的 sess = tf.InteractiveSession() sess.run(tf.initialize_all_variables()) print(sess.run(a)) #[[ 0.43851089 0.29481053 0.39026642 0.7909162 ] #[-0.86328483 -0.97983766 -0.33444929 0.71610451] #[ 0.6196444 -0.5574162 0.00167346 0.14976239]] print(sess.run(b)) # [2 0 0 0] print(sess.run(c))
# [3 3 0]