TensorFlow手写数字识别(三)

  • 在上一篇输出预测错误图片的基础上,本文将继续学习,目的是预测自己用mspaint手写的数字

  • 首先在项目工作文件中创建文件夹:mypaint,然后自己制作手写数字预测图片

  • 下面的代码(get_myData.py)先对自己制作的图片进行预处理,最后得到的是和mnist数据集一样的格式

import os
import cv2
import numpy as np


def get_data():
    path = "mypaint"

    #定义一个列表用于存放多个1*784的矩阵
    Stack = []

    for filename in os.listdir(path):

        print(filename)

    	#读取文件夹中每一个图片
        img = cv2.imread(path+"/"+filename)
        #cv2.imshow("test_imread",img)
        #cv2.waitKey()
        #print(img.shape)

        #调整图片大小至固定尺寸28*28
        img = cv2.resize(img,(28,28),)

        #print(array_of_img)

        #将图片转为灰度图
        img = cv2.cvtColor(img,cv2.COLOR_RGB2GRAY)

        #将图片反色
        img = 255-img

        #将0~255之间的uint8类型的数转换成0~1之间的浮点数
        img = img/255
      	
      	#把二阶28*28的张量变成一阶的1*784张量
        row = img.reshape((784,))

        #将上面每个1*784的张量添加到列表中
        Stack.append(row)

    #下面将列表Stack中的每个元素堆叠成一个矩阵
    picture = np.stack(Stack)

    #print(picture)
    #print(picture.shape)
    return picture

  • 下面的代码对minist数据集进行训练,并对上面处理过的图片进行预测
import get_myData
import tensorflow as tf
import input_data
import numpy as np
#读取数据
mnist = input_data.read_data_sets("MNIST_data",one_hot=True)
#创建可交互的操作单元
x = tf.placeholder("float",[None,784])
w = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))

#实现回归模型,其中y是预测值,并且y是一个m*10的矩阵
y = tf.nn.softmax(tf.matmul(x,w) + b)

#将上面得到的m*10的张量y后面加一个结点,y_out得到的是1*m的一阶张量
y_out = tf.argmax(y,1)
##训练模型
#为了计算交叉熵,首先添加一个新的占位符用于输入正确值,并且y_是一个m*10的矩阵
y_ = tf.placeholder("float",[None,10])

#使用公式  -Σy'log(y)  计算交叉熵
cross_entropy = -tf.reduce_sum(y_*tf.log(y))

#梯度下降
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)

#上面已经把模型设置好了,在运行计算之前,下面先初始化创建的变量
init = tf.initialize_all_variables()

#现在我们可以在一个session里面启动模型,并且初始化向量
sess = tf.Session()
sess.run(init)

#下面开始训练模型,这里我们让模型循环1000次
#在循环的每个步骤中,都会随机抓取训练数据中的100个批处理数据点作为参数替换之前的占位符来运行train_step
for i in range(1000):
	#batch_xs是样本图片,batch_ys是样本的标签
	batch_xs, batch_ys = mnist.train.next_batch(100)
	sess.run(train_step, feed_dict = {x: batch_xs, y_: batch_ys})

#下面对自己手写的数据进行预测
#result = sess.run(y, feed_dict = {x:get_myData.get_data()})

#输出预测结果方法一:
#for i in range(result.shape[0]):
	#print(np.argmax(result[i]))

#输出预测结果方法二:
#print(np.argmax(result,1))

#输出预测结果方法三:
result = sess.run(y_out, feed_dict = {x:get_myData.get_data()})
print(result)
  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值