tf.argmax(input,axis): 求input的axis维度上最大值的所以,同时对input降维,消掉axis的维度。比如input的shape为(1,3,4),axis=2,那么输出的shape为(1,3),若axis=1,那么输出的shape为(3,4)。有的博客上说axis=0是怎么消,axis=1是怎么消除,其实并不能帮助我们解决更高维度,所以我们可以根据输出的维度,来判断是怎么进行计算的
该函数可以用于神经网络分类中,最后预测样本的类型时使用,值最大的索引就是样本的类别编号。所以可以首先建立一个样本编号和样本类型之间的字典。
测试代码:
>>> a=np.array([[[1,2,3],[4,5,6]],[[2,3,1],[3,2,3]]]
... )
>>> tf.argmax(a,2)
<tf.Tensor 'ArgMax:0' shape=(2, 2) dtype=int64>
>>> sess=tf.Session()
>>> sess.run(tf.argmax(a,2))
array([[2, 2],
[1, 0]], dtype=int64)
>>> sess.run(tf.argmax(a,1))
array([[1, 1, 1],
[1, 0, 1]], dtype=int64)
>>> sess.run(tf.argmax(a,0))
array([[1, 1, 0],
[0, 0, 0]], dtype=int64)