对例子解释了一下,直接来吧:
tf.nn.top_k(input, k, name=None)
解释:这个函数的作用是返回 input 中每行最大的 k 个数,并且返回它们所在位置的索引。
import tensorflow as tf
import numpy as np
input = tf.constant(np.random.rand(3,4))
k = 2
"""输出的每行最大的k个数,还有k个数的索引等信息"""
output = tf.nn.top_k(input, k)
with tf.Session() as sess:
print(sess.run(input))
print(sess.run(output))
输出:
[[ 0.11658417 0.0049587 0.34396945 0.80061182]
[ 0.94435975 0.54798914 0.52284388 0.05966983]
[ 0.44605413 0.06890732 0.67666671 0.05019359]]
TopKV2(values=array([[ 0.80061182, 0.34396945],
[ 0.94435975, 0.54798914],
[ 0.67666671, 0.44605413]]), indices=array([[3, 2],
[0, 1],
[2, 0]]))
tf.nn.in_top_k(predictions, targets, k, name=None)
# Says whether the targets are in the top K predictions.
解释:这个函数的作用是返回一个布尔向量,说明目标值targets是否存在于预测值predictions最大的k个之中。
输出数据是一个 targets 长度的布尔向量,如果目标值存在于预测值之中,那么 out[i] = true。
注意:targets 是predictions中的索引位,并不是 predictions 中具体的值。
import tensorflow as tf
import numpy as np
"""给出targets=[1,1,1],里面的每个元素代表各行的索引,在这个索引上的数是否是在该行的最大k个数里。最常用在mnist数据集最后对标签分类的准确率上"""
test = np.array([[4,3,1,2],[2,3,4,1]])
y = np.array([2,2])
with tf.Session() as sess:
test_val= sess.run(tf.nn.in_top_k(test,y,1))
print(test_val)
输出:
#因为k=1,相当于问test[0][2]是不是这一行最大的,test[1][2]是不是这一行最大的,所以结果是这样
[False True]