argmax()官方文档如下:
tf.argmax(input, dimension, name=None)
Returns the index with the largest value across dimensions of a tensor.
Args:
input: A Tensor. Must be one of the following types: float32, float64, int64, int32, uint8, int16, int8, complex64, qint8, quint8, qint32.
dimension: A Tensor of type int32. int32, 0 <= dimension < rank(input). Describes which dimension of the input Tensor to reduce across. For vectors, use dimension = 0.
name: A name for the operation (optional).
Returns:
A Tensor of type int64.
dimension=0 按列找
dimension=1 按行找
tf.argmax()返回最大数值的下标
通常和tf.equal()一起使用,计算模型准确度
correct_pred = tf.equal(tf.argmax(pred,1), tf.argmax(y,1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
import numpy as np
x = np.array([1, 2])
y = np.array([[1],[2]])
print x.shape
print y.shape
>>>
(2,)
(2, 1)
[1,2]的shape值(2,),意思是一维数组,数组中有2个元素。
[[1],[2]]的shape值是(2,1),意思是一个二维数组,每行有1个元素。
[[1,2]]的shape值是(1,2),意思是一个二维数组,每行有2个元素。