Tensorflow如何使用自己cifar10训练模型检测一张任意的图片
研究了cifar10数据集1个月了,终于实现了cifar10训练模型验证一张图片的全部过程,网上给的例子要么caffe实现,要么就是说半截或者是给的代码不全的,实在是无语。自己将我的探究成果写个博客,希望能帮助更多的人少走一些弯路。本博客的代码在官方的例子基础上进行的改版
-
cifar10数据集的简单介绍
共分为10类,具体的分类如下图所示:
60000张图片里面有:
50000张训练样本
10000张测试样本(验证Set)图片是三通道RGB的彩色图片,大小是32x32像素,3*32*32==3*1024==3072,存储在numpy的时候,前1024位是RGB中的R分量像素值,
中间的1024位是G分量的像素值,最后的1024是B分量的像素值
-
cifar10验证单张图片
1. 验证单张图片,你需要先处理读取的图片,将其处理成 [batch_size, height, width, channels]四维的tensor
2. 调用cifar10.py 中的 inference 函数,对输入图片进行卷积、池化、本地化等操作,之后获取最终的 logits
3. 加载训练模型时保存的恢复点,并对测试图片进行预测
首先需要安装prettytable库
sudo pip install prettytable
要实现上诉的操作,常用的有类似MNIST使用placeholder的方式,这种方式可以参看一个不完整的样例 How to classify images using tensorflow cifar10 model
我实现的方式是另一种方式,代码如下:
# -*- coding:utf-8 -*- import tensorflow as tf from tensorflow.python.ops.image_ops_impl import ResizeMethod from prettytable import PrettyTable import cifar10 import numpy as np FLAGS = tf.app.flags.FLAGS # 设置存储模型训练结果的路径 tf.app.flags.DEFINE_string('checkpoint_dir', '/home/xzy/cifar10_train_xzy', """Directory where to read model checkpoints.""") tf.app.flags.DEFINE_string('class_dir', '/home/xzy/cifar10-input/cifar-10-batches-bin/', """存储文件batches.meta.txt的目录""") tf.app.flags.DEFINE_string('test_file', '/home/xzy/dog.jpg', """测试用的图片""") IMAGE_SIZE = 24 def evaluate_images(images): # 执行验证 logits = cifar10.inference(images, batch_size=1) load_trained_model(logits=logits) def load_trained_model(logits): with tf.Session() as sess: ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir) if ckpt and ckpt.model_checkpoint_path: # 从训练模型恢复数据 saver = tf.train.Saver() saver.restore(sess, ckpt.model_checkpoint_path) else: print('No checkpoint file found') return # 下面两行是预测最有可能的分类 # predict = tf.argmax(logits, 1) # output = predict.eval() # 从文件以字符串方式获取10个类标签,使用制表格分割 cifar10_class = np.loadtxt(FLAGS.class_dir + "batches.meta.txt", str, delimiter='\t') # 预测最大的三个分类 top_k_pred = tf.nn.top_k(logits, k=3) output = sess.run(top_k_pred) probability = np.array(output[0]).flatten() # 取出概率值,将其展成一维数组 index = np.array(output[1]).flatten() # 使用表格的方式显示 tabel = PrettyTable(["index", "class", "probability"]) tabel.align["index"] = "l" tabel.padding_width = 1 for i in np.arange(index.size): tabel.add_row([index[i], cifar10_class[index[i]], probability[i]]) print tabel def img_read(filename): if not tf.gfile.Exists(filename): tf.logging.fatal('File does not exists %s', filename) image_data = tf.image.convert_image_dtype(tf.image.decode_jpeg(tf.read_file(filename), channels=3), dtype=tf.float32) height = IMAGE_SIZE width = IMAGE_SIZE image = tf.image.resize_images(image_data, (height, width), method=ResizeMethod.BILINEAR) image = tf.expand_dims(image, -1) image = tf.reshape(image, (1, 24, 24, 3)) return image def main(argv=None): # pylint: disable=unused-argument filename = FLAGS.test_file images = img_read(filename) evaluate_images(images) if __name__ == '__main__': tf.app.run()
上诉的代码,给定一张狗的图片,显示最大的三类识别的结果,最后是用prettytabel打印出的效果由于训练的样本设置成100,模型的准确率反而cat的较高,这很正常+-------+-------+-------------+ | index | class | probability | +-------+-------+-------------+ | 3 | cat | 0.530751 | | 5 | dog | 0.491245 | | 2 | bird | 0.139152 | +-------+-------+-------------+
注意:
1. 对inference函数加入一个参数,让其默认的值是原来的128,验证单张的时候传入1,这样不会影响原来的测试样本集的验证
2. 使用 np.loadtxt 来读取cifar10的10个类,方便后续知道下标,获取cifar10的种类名称
3. tf.nn.top_k(logits, k=3) 来显示最大的3个类别的概率和index下标
4. 使用prettytable格式化输出
完整的代码,见本人github https://github.com/xzy256/cifar10_xzy ,欢迎评论
参考文献
将二进制转换成图片
tensorflow学习之识别单张图片的实现(python手写数字)
tensorflow实现embedding展示的简单快速构建例子
tensorflow使用cifar10模型进行验证单张图片的代码
tensorfloe模块化函数实现cifar10数据集上测试单张图片
cifar10获取10个类别的方法