今天写代码时,发现以前对Top_k理解有误,现在,重新整理对此的理解:
correct = tf.nn.in_top_k(logits, labels, k)
其中:
logits: a tensor of shape [batch_size, NUM_CLASSES]
labels: a tensor of shape [batch_size]
理解:
1.对于logits的某行logits[i],找到其前k个最大的预测值的index_0, .., index_k-1,
如果发现对应的labels[i]在{index_0, …, index_k-1}, 则返回True.(大致这个意思)
2.当k=1时,等价于tf.equal(logits, labels)。但是,equal()函数中的logits和labels的shape必须一样。因此,通过read_data_sets(…, one_hot=False,…)读取数据时,必须使得one_hot=Ture(默认为False).
下面为源码中对此的解释:
This outputs a batch_size
bool array, an entry out[i]
is true
if the prediction for the target class is among the top k
predictions among all predictions for example i
.
注意:
tf.nn.top_k(input, k=1, sorted=sorted, name)
在top_k()函数中,返回的是两个值:
values: input中最后一维部分的前k个最大的值。
indices:与values中各个值对应的索引。
(sorted默认为1, 即根据输出values进行降序输出)