一、tf.argmax() 的使用
简单的说,tf.argmax 就是返回最大的那个数值所在的下标
tf.argmax(input, dimension, name=None)
dimension=0 按列找
dimension=1 按行找
import tensorflow as tf
import numpy as np
A = np.array([[1,3,4,5,6]])
B = np.array([[1,3,4],[2,4,1]])
C = np.array([[3,1,2],[1,2,3],[4,6,5]])
with tf.Session() as sess:
# 按行寻找最大值的下标
print(sess.run(tf.argmax(A,1)))
# 按列寻找最大值的下标
print(sess.run(tf.argmax(B,0)))
print(sess.run(tf.argmax(C,1)))
二、 tf.equal() 的使用
tf.equal(A,B)
是对比这两个矩阵或者向量的相等的元素,
如果是相等的那就返回 True
,反正返回 False
,返回的值的矩阵维度和 A 是一样的
import tensorflow as tf
import numpy as np
A = np.array([[1,3,4,5,6]])
B = np.array([[1,3,4,5,6]])
C = np.array([[3,1,2,5,6]])
with tf.Session() as sess:
print(sess.run(tf.equal(A,B)))
print(sess.run(tf.equal(A,C)))