tf.nn.in_top_k()-解析,以及不适用范围

原创 2018年04月16日 20:14:21

1.in_top_k(predictions, targets, k, name=None)

Args:

predictions: 一种tf.float的张量。一个batch_size的x类张量预测值,one-hot编码,size为[batch_size,label类别数]

如在cifar10的分类上为[128,10]

targets: 一个张量。必须是下列类型之一:int32, int64。size只有一维,也就意味着不能是one-hot编码的。理由举例就知道了

k:每个样本的预测结果的前k个最大的数里面是否包含targets预测中的标签,一般都是取1,即取预测最大概率的索引与标签对比。

name : 操作的名称(可选)。

举例:假设预测值logits为【10,5】的张量,5表示预测为5个类别,labels就为【10】

import tensorflow as tf

logits = tf.Variable(tf.random_normal([10,5],mean=0.0,stddev=1.0,dtype=tf.float32))
labels = tf.constant([0,2,0,1,0,0,4,0,3,0])
top_1_op = tf.nn.in_top_k(logits,labels,1)
top_2_op = tf.nn.in_top_k(logits,labels,2)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(logits.eval())
    print(labels.eval())
    print(top_1_op.eval())
    print(top_2_op.eval())

结果:


解读第一个top_1_op.eval()值False的来源;

首先看第一行,前1个最大的值的索引为1,而labels第一个值为0,不想等,所以为False.以此类推.......

解读第一个top_2_op.eval()值False的来源;

首先看第一行,前2个最大的值的索引分别为1和0,而labels第一个值为0,有一个与labels相等,所以为True.以此类推.......

从这个过程我们就可以知道,labels如果也是一个one-hot编码的话,即使找到logits前一个最大值的索引,你要同labels(假设为【0,0,1,0,0】)去比较值相等,显然是不可能的,因为labels本身就不是一个值,而是一个列表,你怎么将一个数和一个列表比较相不相等呢?所以,用这种方法labels是不能够用one-hot编码的。

举个错误的例子,将这里的labels改为one-hot编码。看看报错怎么样。

import tensorflow as tf

logits = tf.Variable(tf.random_normal([10,5],mean=0.0,stddev=1.0,dtype=tf.float32))
labels = tf.constant([0,2,0,1,0,0,4,0,3,0])
n_classes = 5
labels = tf.one_hot(labels, depth=n_classes)
print(labels)
top_1_op = tf.nn.in_top_k(logits,labels,1)
top_2_op = tf.nn.in_top_k(logits,labels,2)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(logits.eval())
    print(labels.eval())
    print(top_1_op.eval())
    print(top_2_op.eval())

Tensor("one_hot:0", shape=(10, 5), dtype=float32)

TypeError: Value passed to parameter 'targets' has DataType float32 not in list of allowed values: int32, int64







4用于cifar10的卷积神经网络-4.4/4.5cifar10数据集读取和数据增强扩充(上/下)

4用于cifar10的卷积神经网络-4.4/4.5cifar10数据集读取和数据增强扩充(上/下) 参考: https://github.com/tensorflow http://www.cs...
  • hongxue8888
  • hongxue8888
  • 2017-11-15 19:23:15
  • 187

多点透视cvSolvePNP的替代函数

在调试JNI程序时,所有的Shell都已经加载完成,而唯一真正核心的cv::SolvePnP却不能在JNI里面获得通行证,经过反复测试都不能运行,因此只能忍痛舍弃,自行编写一个具有相似功能的函数对其进...
  • BBZZ2
  • BBZZ2
  • 2016-09-13 10:19:20
  • 898

TensorFlow实战:Chapter-3(CNN-1-卷积神经网络简介)

卷积神经网络简介 CNN的提出 CNN的壮大 卷积神经网络结构 卷积神经网络的常见网络结构 卷积 信号处理中的卷积 图像处理中的卷积 卷积的类型的参数 卷积层 卷积层原理 卷积层算法 卷积层特点 权值...
  • u011974639
  • u011974639
  • 2017-07-19 12:14:05
  • 10134

富文本编辑+fs操作文件+Buffer练习(头像上传功能)

富文本编辑内容引用=>1.UEditor是由百度web前端研发部开发所见即所得富文本web编辑器下载的文件 引入目录文件进来:文件上传功能引用文件: require(‘../ueditor/’);...
  • qq_26766283
  • qq_26766283
  • 2017-06-06 00:42:35
  • 7472

ROS进二阶学习笔记(1) TF 学习笔记1:TF介绍 + tf工具

ROS TF 学习笔记(1) -- TF介绍 Ref:  http://wiki.ros.org/tf/Tutorials/Introduction%20to%20tf 惭愧的是,时隔10个月,重...
  • sonictl
  • sonictl
  • 2016-08-11 15:51:16
  • 2207

如何用Tensorflow实现增强版本的Mnist手写识别网络模型

上一个阶段构造的简单模型训练后,只有91%正确率。本文章讲解如何用一个稍微复杂的模型:卷积神经网络来提升效果。 下面是效果图: 图1具体的步骤为:首先完成准备工作:import inpu...
  • zjwcdd
  • zjwcdd
  • 2016-09-03 14:41:23
  • 905

OSGI学习笔记一(事件传递)

一、定义在jujianzh传递事件的类
  • waterbbx
  • waterbbx
  • 2014-11-07 18:35:08
  • 9636

TensorFlow入门(三)多层 CNNs 实现 mnist分类

之前在keras中用同样的网络和同样的数据集来做这个例子的时候。keras占用了 5647M 的显存(训练过程中设了 validation_split = 0.2, 也就是1.2万张图)。 但是我用 ...
  • c2a2o2
  • c2a2o2
  • 2017-04-21 13:26:25
  • 280

利用CNN(卷积神经网络)训练mnist数据集

本文参考了经典的LeNet-5卷积神经网络模型对mnist数据集进行训练。LeNet-5模型是大神Yann LeCun于1998年在论文"Gradient-based learning ap...
  • ChuiGeDaQiQiu
  • ChuiGeDaQiQiu
  • 2018-02-22 13:41:26
  • 126
收藏助手
不良信息举报
您举报文章:tf.nn.in_top_k()-解析,以及不适用范围
举报原因:
原因补充:

(最多只允许输入30个字)