Tensorflow MNIST原始图片TFRecord方式识别---4. 取多张手写数字图片进行测试

Tensorflow MNIST原始图片TFRecord方式识别---4. 取多张手写数字图片进行测试

本章节是《Tensorflow MNIST原始图片TFRecord方式识别》的最后一节,验证模型,测试模型识别准确率。

1. 测试图片数据预处理

验证训练模型的准确率,首先需要将测试图片作预处理,转为像素矩阵。测试图片数据预处理需要满足如下条件:

  • 像素矩阵算法一致性

训练样本的像素矩阵进行了怎样的算法运算,测试样本的像素矩阵需要进行同样的算法运算;比如训练样本 image_data乘以1.0/255, 那么测试样本也需要 test_data乘以1.0/255。

  • shape一致性

和每个训练样本的shape一样

import tensorflow as tf
import inference
import os
import cv2
import numpy as np

#神经网络相关参数
BATCH_SIZE = 100
LEARNING_RATE_BASE = 0.01
LEARNING_RATE_DECAY = 0.99
REGULARIZATION_RATE = 0.0001
TRAINING_STEPS = 6000
MOVING_AVERAGE_DECAY = 0.99

def image_prepare(file_name):
    image_data = cv2.imread(file_name)
    image = cv2.cvtColor(image_data, cv2.COLOR_RGB2GRAY)
    #因为模型训练时,训练样本数据像素矩阵经过如下算法
    image_raw = image * 1.0/255
    # image_raw = (255-image) * 1.0/255
    image_raw = image_raw.astype(np.float32)
    image_raw = np.reshape(image_raw, [28,28,1])
    return image_raw

如下是训练样本的图例:
在这里插入图片描述
如果测试样本是如下图例,则需要改图像数据的算法为image_raw = (255-image) * 1.0/255;而且size也得处理成一致的。
在这里插入图片描述

2. 测试已训练模型准确率

使用tf.train.Saver()类的restore接口,加载已经训练好的模型,从输入测试图片数据到加载模型,中间过程需要和训练时保持一致【有的文章这样讲,但是不需要】。
测试接口的代码,如下:

def test(pic_data):
    x = tf.placeholder(tf.float32, [
                1,
                28,
                28,
                1],
            name='x-input')
    y_ = tf.placeholder(tf.float32, [None, 10], name='y-input')
    # L2正则化
    regularizer = tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE)
    y = inference.inference(x,False,regularizer)
    # 初始化TensorFlow持久化类。
    saver = tf.train.Saver()
    with tf.Session() as sess:
        # 启动多线程处理输入数据, 将样本数据填入到队列,为训练读取数据做好准备
        # 否则训练过程会一致堵塞,处于等待数据的状态。
        # 采用Coordinator对象为了,当这些线程发生异常时,关闭这些线程
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        saver.restore(sess, '../src/LeNet_Mnist_Origin/ckpt_dir/hand_shuffle_model.ckpt')
        # print("y:", sess.run(y, feed_dict = {x: [pic_data]}))
        #经过CNN得到一维向量,长度为10(0-9十个分类的计算值),最大元素值的index下标值为识别结果。
        predict_result = sess.run(tf.arg_max(y,1), feed_dict = {x: [pic_data]})
        predict_res = predict_result[0]
        coord.request_stop()
        coord.join(threads)
    
    return predict_res

if __name__ == '__main__':
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
    pic_path = '../datasets/MNIST_PNG_Data/test/'
    '''
    pic_path = 'C:/Users/Administrator/Pictures/00022.png'
    print(pic_path)
    pic_data = image_prepare(pic_path)
    result = test(pic_data)
    print("识别结果: ", result)
    '''

    # 这个识别的准确率还是比较高的。
    for num in range(10):
        i = 0
        j = 0
        pic_path1 = os.path.join(pic_path, str(num))
        for file in os.listdir(pic_path1):
            file_name = os.path.join(pic_path1,file)
            pic_data = image_prepare(file_name)
            result = test(pic_data)
            tf.reset_default_graph()
            if result == num:
                # print("第%d张图片,识别结果成功: %d" %(i, result))
                i += 1
            else:
                j += 1
        print("%d 的测试样本数: %d, 其识别准确率为: %f" %(num, i+j, i/(i+j)))

测试结果如下,准确率和使用tensorflow内置标准MNIST数据集时,不分上下。
在这里插入图片描述
在这里插入图片描述

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值