文件总览
- mnist_app.py:新生的文件
- mnist_forward.py:未变
- mnist_backward.py:未变
- mnist_test.py:未变
网络输入
- 网络输入:一维数组(784 个像素点)
- 像素点:0-1 之间的浮点数(接近 0 越黑,接近 1 越白)
网络输出
网络输出:一维数组(十个可能性概率),数组中最大的那个元素所对应的索引号就是预测的结果
关键处理
def application(): #输入要识别的几张图片(注意要给出待识别图片的路径和名称)
testNum = input("input the number of test pictures:")
for i in range(testNum):
testPic = raw_input("the path of test picture:")
testPicArr = pre_pic(testPic)
preValue = restore_model(testPicArr)
print ('The prediction number is:', preValue )
- 注解:
任务分成两个函数完成- testPicArr = pre_pic(testPic)对手写数字图片做预处理
- preValue = restore_model(testPicArr) 将符合神经网络输入要求的图片喂给复现的神经网络模型,输出预测值
具体代码
- mnist_app.py
#coding:utf-8
import tensorflow as tf
import numpy as np
from PIL import Image
import mnist_backward
import mnist_forward
def restore_model(testPicArr):
#创建一个默认图,在该图中执行以下操作(多数操作和train中一样)
with tf.Graph().as_default() as tg:
x = tf.placeholder(tf.float32, [None, mnist_forward.INPUT_NODE])
y = mnist_forward(x, None)
preValue = tf.argmax(y, 1) #计算求得输出 y,y 的最大值所对应的列表索引号就是预测结果
#实现滑动平均值,参数MOVING_AVRAGE_DECAY用于控制模型更新速度,训练过程中会对每个变量维护一个影子变量,该影子变量的初始值就是相应变量的初始值,每次变量更新时,影子变量随之更新
variable_averages = tf.ExpontialMovingAverage(mnist_backward.MOVING_AVERAGE_DECAY)
variables_to_restore = variable_averages.variables_to_restore()
saver = tf.train.Saver(variables_to_restore)
#断点续训
with tf.Session() as sess:
#通过checkpoint文件定位到最新保存的模型
ckpt = tf.train.get_checkpoint_state(mnist_backward.MODEL_SAVE_PATH)
if ckpt and ckpt.model_checkpoint_path:
saver.restore(sess, ckpt.model_checkpoint_path)
preValue = sess.run(preValue, feed_dict = {x:testPicArr})
return preValue
else:
print('No checkpoint file found')
return -1
#预处理函数,包括resize、转变灰度图、二值化操作
def pre_pic(picName):
img = Image.open(picName)
reIm = img.resize((28, 28), Image.ANTIALIAS)
im_arr = np.array(reIm.convert('L')) #对图片做二值化处理(这样以滤掉噪声,另外调试中可适当调节阈值)
threshold = 50 #设置合理阈值
for i in range(28): #遍历所有像素点
for j in range(28):
im_arr[i][j] = 255 - im_arr[i][j] #模型的要求是黑底白字,但输入的图是白底黑字,所以需要对每个像素点的值改为 255 减去原值以得到互补的反色
if (im_arr[i][j] < threshold):
im_arr[i][j] = 0 #像素小于阈值视为黑色
else:
im_arr[i][j] = 255
nm_arr = im_arr.reshape([1, 784]) #把图片形状拉成 1 行 784 列
nm_arr = nm_arr.astype(np.float32) #并把值变为浮点型(因为要求像素点是 0-1之间的浮点数)
img_ready = np.multiply(nm_arr, 1.0/255.0) #接着让现有的RGB图从 0-255 之间的数变为 0-1 之间的浮点数
return img_ready
def application(): #输入要识别的几张图片(注意要给出待识别图片的路径和名称)
testNum = input("input the number of test pictures:") #总共要待识别图片的个数
testNum = int(testNum)
for i in range(testNum):
testPic = input(r"the path of test picture:") #待识别图片的路径,路径前边加上r即可,禁止字符串转义
testPicArr = pre_pic(testPic)
preValue = restore_model(testPicArr)
print ('The prediction number is:', preValue )
def main():
application()
if __name__ == '__main__':
main()
实践
与MNIST 数据集输出手写数字识别准确率——实践相比,多了mnist_app.py文件
- 运行 mnist_forward.py
- 运行 mnist_backward.py
- 运行 mnist_test.py 来监测模型的准确率
- 运行 mnist_app.py 输入 10(表示循环验证十张图片)
- 先输入10
- 再输入input_data/0.jpg
问题:
运行 mnist_app.py时一直提示如下错误,望知道的小伙伴告知