Tensorflow MNIST原始图片TFRecord方式识别---4. 取多张手写数字图片进行测试
本章节是《Tensorflow MNIST原始图片TFRecord方式识别》的最后一节,验证模型,测试模型识别准确率。
1. 测试图片数据预处理
验证训练模型的准确率,首先需要将测试图片作预处理,转为像素矩阵。测试图片数据预处理需要满足如下条件:
- 像素矩阵算法一致性
训练样本的像素矩阵进行了怎样的算法运算,测试样本的像素矩阵需要进行同样的算法运算;比如训练样本 image_data乘以1.0/255, 那么测试样本也需要 test_data乘以1.0/255。
- shape一致性
和每个训练样本的shape一样
import tensorflow as tf
import inference
import os
import cv2
import numpy as np
#神经网络相关参数
BATCH_SIZE = 100
LEARNING_RATE_BASE = 0.01
LEARNING_RATE_DECAY = 0.99
REGULARIZATION_RATE = 0.0001
TRAINING_STEPS = 6000
MOVING_AVERAGE_DECAY = 0.99
def image_prepare(file_name):
image_data = cv2.imread(file_name)
image = cv2.cvtColor(image_data, cv2.COLOR_RGB2GRAY)
#因为模型训练时,训练样本数据像素矩阵经过如下算法
image_raw = image * 1.0/255
# image_raw = (255-image) * 1.0/255
image_raw = image_raw.astype(np.float32)
image_raw = np.reshape(image_raw, [28,28,1])
return image_raw
如下是训练样本的图例:
如果测试样本是如下图例,则需要改图像数据的算法为image_raw = (255-image) * 1.0/255;而且size也得处理成一致的。
2. 测试已训练模型准确率
使用tf.train.Saver()类的restore接口,加载已经训练好的模型,从输入测试图片数据到加载模型,中间过程需要和训练时保持一致【有的文章这样讲,但是不需要】。
测试接口的代码,如下:
def test(pic_data):
x = tf.placeholder(tf.float32, [
1,
28,
28,
1],
name='x-input')
y_ = tf.placeholder(tf.float32, [None, 10], name='y-input')
# L2正则化
regularizer = tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE)
y = inference.inference(x,False,regularizer)
# 初始化TensorFlow持久化类。
saver = tf.train.Saver()
with tf.Session() as sess:
# 启动多线程处理输入数据, 将样本数据填入到队列,为训练读取数据做好准备
# 否则训练过程会一致堵塞,处于等待数据的状态。
# 采用Coordinator对象为了,当这些线程发生异常时,关闭这些线程
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
saver.restore(sess, '../src/LeNet_Mnist_Origin/ckpt_dir/hand_shuffle_model.ckpt')
# print("y:", sess.run(y, feed_dict = {x: [pic_data]}))
#经过CNN得到一维向量,长度为10(0-9十个分类的计算值),最大元素值的index下标值为识别结果。
predict_result = sess.run(tf.arg_max(y,1), feed_dict = {x: [pic_data]})
predict_res = predict_result[0]
coord.request_stop()
coord.join(threads)
return predict_res
if __name__ == '__main__':
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
pic_path = '../datasets/MNIST_PNG_Data/test/'
'''
pic_path = 'C:/Users/Administrator/Pictures/00022.png'
print(pic_path)
pic_data = image_prepare(pic_path)
result = test(pic_data)
print("识别结果: ", result)
'''
# 这个识别的准确率还是比较高的。
for num in range(10):
i = 0
j = 0
pic_path1 = os.path.join(pic_path, str(num))
for file in os.listdir(pic_path1):
file_name = os.path.join(pic_path1,file)
pic_data = image_prepare(file_name)
result = test(pic_data)
tf.reset_default_graph()
if result == num:
# print("第%d张图片,识别结果成功: %d" %(i, result))
i += 1
else:
j += 1
print("%d 的测试样本数: %d, 其识别准确率为: %f" %(num, i+j, i/(i+j)))
测试结果如下,准确率和使用tensorflow内置标准MNIST数据集时,不分上下。