-
在上一篇输出预测错误图片的基础上,本文将继续学习,目的是预测自己用mspaint手写的数字
-
首先在项目工作文件中创建文件夹:mypaint,然后自己制作手写数字预测图片
-
下面的代码(get_myData.py)先对自己制作的图片进行预处理,最后得到的是和mnist数据集一样的格式
import os
import cv2
import numpy as np
def get_data():
path = "mypaint"
#定义一个列表用于存放多个1*784的矩阵
Stack = []
for filename in os.listdir(path):
print(filename)
#读取文件夹中每一个图片
img = cv2.imread(path+"/"+filename)
#cv2.imshow("test_imread",img)
#cv2.waitKey()
#print(img.shape)
#调整图片大小至固定尺寸28*28
img = cv2.resize(img,(28,28),)
#print(array_of_img)
#将图片转为灰度图
img = cv2.cvtColor(img,cv2.COLOR_RGB2GRAY)
#将图片反色
img = 255-img
#将0~255之间的uint8类型的数转换成0~1之间的浮点数
img = img/255
#把二阶28*28的张量变成一阶的1*784张量
row = img.reshape((784,))
#将上面每个1*784的张量添加到列表中
Stack.append(row)
#下面将列表Stack中的每个元素堆叠成一个矩阵
picture = np.stack(Stack)
#print(picture)
#print(picture.shape)
return picture
- 下面的代码对minist数据集进行训练,并对上面处理过的图片进行预测
import get_myData
import tensorflow as tf
import input_data
import numpy as np
#读取数据
mnist = input_data.read_data_sets("MNIST_data",one_hot=True)
#创建可交互的操作单元
x = tf.placeholder("float",[None,784])
w = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))
#实现回归模型,其中y是预测值,并且y是一个m*10的矩阵
y = tf.nn.softmax(tf.matmul(x,w) + b)
#将上面得到的m*10的张量y后面加一个结点,y_out得到的是1*m的一阶张量
y_out = tf.argmax(y,1)
##训练模型
#为了计算交叉熵,首先添加一个新的占位符用于输入正确值,并且y_是一个m*10的矩阵
y_ = tf.placeholder("float",[None,10])
#使用公式 -Σy'log(y) 计算交叉熵
cross_entropy = -tf.reduce_sum(y_*tf.log(y))
#梯度下降
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
#上面已经把模型设置好了,在运行计算之前,下面先初始化创建的变量
init = tf.initialize_all_variables()
#现在我们可以在一个session里面启动模型,并且初始化向量
sess = tf.Session()
sess.run(init)
#下面开始训练模型,这里我们让模型循环1000次
#在循环的每个步骤中,都会随机抓取训练数据中的100个批处理数据点作为参数替换之前的占位符来运行train_step
for i in range(1000):
#batch_xs是样本图片,batch_ys是样本的标签
batch_xs, batch_ys = mnist.train.next_batch(100)
sess.run(train_step, feed_dict = {x: batch_xs, y_: batch_ys})
#下面对自己手写的数据进行预测
#result = sess.run(y, feed_dict = {x:get_myData.get_data()})
#输出预测结果方法一:
#for i in range(result.shape[0]):
#print(np.argmax(result[i]))
#输出预测结果方法二:
#print(np.argmax(result,1))
#输出预测结果方法三:
result = sess.run(y_out, feed_dict = {x:get_myData.get_data()})
print(result)