Tensorflow 如何使用自己cifar10训练模型检测一张任意的图片

Tensorflow如何使用自己cifar10训练模型检测一张任意的图片

研究了cifar10数据集1个月了,终于实现了cifar10训练模型验证一张图片的全部过程,网上给的例子要么caffe实现,要么就是说半截或者是给的代码不全的,实在是无语。自己将我的探究成果写个博客,希望能帮助更多的人少走一些弯路。本博客的代码在官方的例子基础上进行的改版


  • cifar10数据集的简单介绍

共分为10类,具体的分类如下图所示:
60000张图片里面有:
50000张训练样本
10000张测试样本(验证Set)

图片是三通道RGB的彩色图片,大小是32x32像素,3*32*32==3*1024==3072,存储在numpy的时候,前1024位是RGB中的R分量像素值,

中间的1024位是G分量的像素值,最后的1024是B分量的像素值

  • cifar10验证单张图片

1. 验证单张图片,你需要先处理读取的图片,将其处理成 [batch_size, height, width, channels]四维的tensor

2. 调用cifar10.py 中的 inference 函数,对输入图片进行卷积、池化、本地化等操作,之后获取最终的 logits

3. 加载训练模型时保存的恢复点,并对测试图片进行预测

首先需要安装prettytable库

sudo pip install prettytable

要实现上诉的操作,常用的有类似MNIST使用placeholder的方式,这种方式可以参看一个不完整的样例 How to classify images using tensorflow cifar10 model

我实现的方式是另一种方式,代码如下:

# -*- coding:utf-8 -*-
import tensorflow as tf
from tensorflow.python.ops.image_ops_impl import ResizeMethod
from prettytable import PrettyTable  

import cifar10
import numpy as np

FLAGS = tf.app.flags.FLAGS
# 设置存储模型训练结果的路径
tf.app.flags.DEFINE_string('checkpoint_dir', '/home/xzy/cifar10_train_xzy',
                           """Directory where to read model checkpoints.""")
tf.app.flags.DEFINE_string('class_dir', '/home/xzy/cifar10-input/cifar-10-batches-bin/',
                           """存储文件batches.meta.txt的目录""")
tf.app.flags.DEFINE_string('test_file', '/home/xzy/dog.jpg', """测试用的图片""")
IMAGE_SIZE = 24


def evaluate_images(images):  # 执行验证
    logits = cifar10.inference(images, batch_size=1)
    load_trained_model(logits=logits)


def load_trained_model(logits):
    with tf.Session() as sess:
        ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            # 从训练模型恢复数据
            saver = tf.train.Saver()
            saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            print('No checkpoint file found')
            return

        # 下面两行是预测最有可能的分类
        # predict = tf.argmax(logits, 1)
        # output = predict.eval()

        # 从文件以字符串方式获取10个类标签,使用制表格分割
        cifar10_class = np.loadtxt(FLAGS.class_dir + "batches.meta.txt", str, delimiter='\t')
        # 预测最大的三个分类
        top_k_pred = tf.nn.top_k(logits, k=3)
        output = sess.run(top_k_pred)
        probability = np.array(output[0]).flatten()  # 取出概率值,将其展成一维数组
        index = np.array(output[1]).flatten()
        # 使用表格的方式显示
        tabel = PrettyTable(["index", "class", "probability"])
        tabel.align["index"] = "l"  
        tabel.padding_width = 1 
        for i in np.arange(index.size):
            tabel.add_row([index[i], cifar10_class[index[i]], probability[i]])
        print tabel


def img_read(filename):
    if not tf.gfile.Exists(filename):
        tf.logging.fatal('File does not exists %s', filename)
    image_data = tf.image.convert_image_dtype(tf.image.decode_jpeg(tf.read_file(filename),
                                                                   channels=3), dtype=tf.float32)
    height = IMAGE_SIZE
    width = IMAGE_SIZE
    image = tf.image.resize_images(image_data, (height, width), method=ResizeMethod.BILINEAR)
    image = tf.expand_dims(image, -1)
    image = tf.reshape(image, (1, 24, 24, 3))
    return image


def main(argv=None):  # pylint: disable=unused-argument
    filename = FLAGS.test_file
    images = img_read(filename)
    evaluate_images(images)


if __name__ == '__main__':
    tf.app.run()

上诉的代码,给定一张狗的图片,显示最大的三类识别的结果,最后是用prettytabel打印出的效果

+-------+-------+-------------+
| index | class | probability |
+-------+-------+-------------+
| 3     |  cat  |   0.530751  |
| 5     |  dog  |   0.491245  |
| 2     |  bird |   0.139152  |
+-------+-------+-------------+
由于训练的样本设置成100,模型的准确率反而cat的较高,这很正常

注意:

1. 对inference函数加入一个参数,让其默认的值是原来的128,验证单张的时候传入1,这样不会影响原来的测试样本集的验证

2. 使用  np.loadtxt 来读取cifar10的10个类,方便后续知道下标,获取cifar10的种类名称

3. tf.nn.top_k(logits, k=3) 来显示最大的3个类别的概率和index下标

4. 使用prettytable格式化输出


完整的代码,见本人github  https://github.com/xzy256/cifar10_xzy   ,欢迎评论

参考文献

将二进制转换成图片                
tensorflow学习之识别单张图片的实现(python手写数字)           
tensorflow实现embedding展示的简单快速构建例子            
tensorflow使用cifar10模型进行验证单张图片的代码             
tensorfloe模块化函数实现cifar10数据集上测试单张图片     
cifar10获取10个类别的方法        


评论 26
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值