argmax(input, axis=None, name=None, dimension=None, output_type=tf.int64)
Returns the index with the largest value across axes of a tensor. (deprecated arguments)
根据axis的值返回行或者列最大值的下标,axis取值[-2,2)
上代码
#创建一个2*3的数组,使用随机种子,保证数据不变
a = tf.Variable(tf.random_normal([2,3],seed = 12345))
#初始化变量
init = tf.global_variables_initializer()
#对每一列进行计算,返回最大值下标
b = tf.argmax(a,0)
#启动会话层
with tf.Session() as sess:
sess.run(init)
print(sess.run(a))
print(sess.run(b))
输出结果:
原始数据:
[[ 0.88424665 0.07843047 0.13639879]
[-0.6109575 1.8525681 -1.1506747 ]]
返回最大值下标索引
[0 1 0]
它是如何返回 [ 0 1 0]的呢
计算流程:
如图,把原始数据使用两根红线竖着分割成三份(按列分割),从上到下进行对比(按行对比),即:0.088424665和-0.6109575进行对比,返回最大值的下标,0.07843047和1.8525681进行对比,返回最大值的下标,0.13639879和-1.1506747进行对比,返回最大值的下标。
0.088424665(下标0)大于 -0.6109575(下标1) ,返回0
0.07843047(下标0) 小于 1.8525681(下标1) , 返回1
0.13639879(下标0) 大于 -1.1506747(下标1), 返回0
最终得到结果[ 0 1 0]
接下来axis 设置为 1
#创建一个2*3的数组,使用随机种子,保证数据不变
a = tf.Variable(tf.random_normal([2,3],seed = 12345))
#初始化变量
init = tf.global_variables_initializer()
#对每一列进行计算,返回最大值下标
b = tf.argmax(a,1)
#启动会话层
with tf.Session() as sess:
sess.run(init)
print('原始数据:')
print(sess.run(a))
print('返回最大值下标索引')
print(sess.run(b))
输出结果:
原始数据:
[[ 0.88424665 0.07843047 0.13639879]
[-0.6109575 1.8525681 -1.1506747 ]]
返回最大值下标索引
[0 1]
计算流程:
如图,把原始数据使用一根红线竖着分割成两份(按行分割),从左到右进行对比(按列对比),即:0.088424665和0.07843047和0.13639879进行对比,返回最大值的下标,-0.6109575和1.8525681和-1.1506747进行对比,返回最大值的下标。
0.88424665(下标0) 0.07843047(下标1) 0.13639879(下标2) ,经过比较0.88424665最大,返回0
-0.6109575(下标0) 1.8525681(下标1) -1.1506747(下标2) ,经过比较1.8525681最大,返回1
最终得到结果[0 1]
总结:
tf.argmax函数根据axis的值进行 行索引或者列索引,axis取值[-2,2),半开半闭区间,
当axis = 0或-2时 (按列分割),(按行对比)
当axis = 1或-1时 (按行分割),(按列对比)