tf.nn.in_top_k()-解析,以及不适用范围

1.in_top_k(predictions, targets, k, name=None)

Args:

predictions: 一种tf.float的张量。一个batch_size的x类张量预测值,one-hot编码,size为[batch_size,label类别数]

如在cifar10的分类上为[128,10]

targets: 一个张量。必须是下列类型之一:int32, int64。size只有一维,也就意味着不能是one-hot编码的。理由举例就知道了

k:每个样本的预测结果的前k个最大的数里面是否包含targets预测中的标签,一般都是取1,即取预测最大概率的索引与标签对比。

name : 操作的名称(可选)。

举例:假设预测值logits为【10,5】的张量,5表示预测为5个类别,labels就为【10】

import tensorflow as tf

logits = tf.Variable(tf.random_normal([10,5],mean=0.0,stddev=1.0,dtype=tf.float32))
labels = tf.constant([0,2,0,1,0,0,4,0,3,0])
top_1_op = tf.nn.in_top_k(logits,labels,1)
top_2_op = tf.nn.in_top_k(logits,labels,2)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(logits.eval())
    print(labels.eval())
    print(top_1_op.eval())
    print(top_2_op.eval())

结果:


解读第一个top_1_op.eval()值False的来源;

首先看第一行,前1个最大的值的索引为1,而labels第一个值为0,不想等,所以为False.以此类推.......

解读第一个top_2_op.eval()值False的来源;

首先看第一行,前2个最大的值的索引分别为1和0,而labels第一个值为0,有一个与labels相等,所以为True.以此类推.......

从这个过程我们就可以知道,labels如果也是一个one-hot编码的话,即使找到logits前一个最大值的索引,你要同labels(假设为【0,0,1,0,0】)去比较值相等,显然是不可能的,因为labels本身就不是一个值,而是一个列表,你怎么将一个数和一个列表比较相不相等呢?所以,用这种方法labels是不能够用one-hot编码的。

举个错误的例子,将这里的labels改为one-hot编码。看看报错怎么样。

import tensorflow as tf

logits = tf.Variable(tf.random_normal([10,5],mean=0.0,stddev=1.0,dtype=tf.float32))
labels = tf.constant([0,2,0,1,0,0,4,0,3,0])
n_classes = 5
labels = tf.one_hot(labels, depth=n_classes)
print(labels)
top_1_op = tf.nn.in_top_k(logits,labels,1)
top_2_op = tf.nn.in_top_k(logits,labels,2)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(logits.eval())
    print(labels.eval())
    print(top_1_op.eval())
    print(top_2_op.eval())

Tensor("one_hot:0", shape=(10, 5), dtype=float32)

TypeError: Value passed to parameter 'targets' has DataType float32 not in list of allowed values: int32, int64







阅读更多
文章标签: in_top_k
个人分类: tensorflow
上一篇Kevin Xu-TensorFlow Tutorials-cifar10 (2)
下一篇tensorflow实现AlexNet-黄文坚版本
想对作者说点什么? 我来说一句

没有更多推荐了,返回首页

关闭
关闭