在处理一系列分类问题时,经常会用到tf.nn.in_top_k这一函数,其调用形式如下tf.nn.in_top_k(predictions, targets, k, name=None)
其中:
predictions: 你的预测结果(一般也就是你的网络输出值)大小是预测样本的数量乘以输出的维度
target: 实际样本类别的标签,大小是样本数量的个数
k: 每个样本中前K个最大的数里面(序号)是否包含对应target中的值
下面我们结合例子,探究tf.nn.in_top_k的用法
import tensorflow as tf;
A = [[0.8,0.3,0.7], [0.1,0.9,0.5]]
B = [1, 2]
out = tf.nn.in_top_k(A, B, 2)#或out = tf.nn.in_top_k(A, B,1)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print(sess.run(out))
当k取1的时候:
A中第一个元素的最大值为0.8,索引(序号)是0,而B的第一个元素是1,0不等于1,所以返回False.
A中第二个元素的最大值为0.9,索引(序号)是1,而B的第而个元素是2,1不等于2,所以返回False.
当k取2的时候:
A中前两个元素的最大值为0.8,0.7,索引(序号)是0,2,而B的第一个元素是1,0与2都不等于1,所以返回False.
A中前两个元素的最大值为0.9,0.5,索引(序号)是1,2,而B的第一个元素是2,1与2中有2等于2,所以返回True.