原文链接:https://blog.csdn.net/u013580539/article/details/79339250
关于tf.argmax,我看到网上的资料有些杂乱难以理解,所以写这篇文章。在tf.argmax( , )中有两个参数,第一个参数是矩阵,第二个参数是0或者1。0表示的是按列比较返回最大值的索引,1表示按行比较返回最大值的索引。下面上代码:
import tensorflow as tf
Vector = [1,1,2,5,3] #定义一个向量
X = [[1,3,2],[2,5,8],[7,5,9]] #定义一个矩阵
with tf.Session() as sess:
a = tf.argmax(Vector, 0)
b = tf.argmax(X, 0)
c = tf.argmax(X, 1)
print(sess.run(a))
print(sess.run(b))
print(sess.run(c))
输出的结果是:
3
[2 1 2]
[1 2 2]
【-----------------------------------2019.2.25补充-------------------------------------】
其实我们操作的多维数组,比如矩阵:
[[1, 2, 3]
[ 4, 5, 6]]
它的形状为(2, 3),其中第一个数字为“2”,第二个数字为“3”,也就是说(2, 3)这个形状的索引[0]是“2”,索引[1]是“3”,而我们这里是一个两行三列的矩阵。说到这里就得再牵扯一下n维数组,n维数组的形状是 (行数, 列数, …),那么既然如此,我们argmax()的第二个参数为0的时候表示按列比较,为1的时候表示按行比较。。。。。。。。。。
对,这里的0和1,表示的就是形状的索引!参数为0时,是根据(2, 3)0[0],即行的方向进行比较,参数为1时,是按照(2, 3)1[1],即列的方向进行比较!实际上,这里的“行”,也即为我们的样本数量,毕竟我们都知道,在tensorflow中我们的数据格式下标[0]的是batchsize
实验验证我们的理论:
import tensorflow as tf
x = [[1, 2, 3],[ 4, 5, 6]]
with tf.Session() as sess:
d1 = tf.argmax(x, 0)
d2 = tf.argmax(x, 1)
print(sess.run(d1))
print(sess.run(d2))
先猜一猜输出结果,再查看答案哟 ~~
答案是:
[1 1 1]
[2 2]
你,做对了吗?
到现在呢,可能大家还有一个疑问:对于Vector = [1,1,2,5,3],为什么argmax的参数写0呢,这里不是应该按照1的方向比较吗?
其实啊,这里的Vector是向量,并不是矩阵,而向量是只有一个维度的,所以,我们argmax中的参数就只能为0