Tensorflow(1):MNIST识别自己手写的数字--入门篇(Softmax回归)

  机器学习入门都是从MNIST开始,Tensorflow官方社区提供了十分详细的教程【MNIST机器学习入门】。但是我们显然不满足于仅仅把官方的代码复制一遍然后输出个结果,我们想能不能实现自己手写数字的识别。
  本文作为Tensorflow入门,结合官方代码,利用Softmax回归函数,实现模型的训练、保存、以及重新加载,完成对自己手写数字的识别。

1.模型训练及保存

  模型我们采用Softmax回归函数,具体代码参考【MNIST机器学习入门】,这里用梯度下降算法以0.01学习率最小化交叉熵对模型进行1000次训练。

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)  # 插入数据

# name在保存模型时非常有用
x = tf.placeholder("float", [None, 784], name='x') 
W = tf.Variable(tf.zeros([784, 10]), name='W')  
b = tf.Variable(tf.zeros([10]), name='b')
y = tf.nn.softmax(tf.matmul(x, W) + b, name='y')  # y预测概率分布

y_ = tf.placeholder("float", [None, 10])  # y_实际概率分布

cross_entropy = -tf.reduce_sum(y_ * tf.log(y))  # 交叉熵
# 梯度下降算法以0.01学习率最小化交叉熵
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)  
init = tf.initialize_all_variables()  # 初始化变量

sess = tf.Session()
sess.run(init)
saver = tf.train.Saver()

for i in range(1000):  # 开始训练模型,循环1000次
    batch_xs, batch_ys = mnist.train.next_batch(100)
    sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
saver.save(sess, 'minst_model.ckpt')  # 保存模型

  在代码最前面,定义张量(变量)时,我们给每个张量(变量)都加了name关键字,这个对我们后期再次加载模型很重要。
  在最后,我们利用saver.save()函数,保存模型。模型名称为minst_model.ckpt。之后我们可以在文件夹下看到4个文件:

  • checkpoint: 保存目录下所有模型文件列表
  • minst_model.ckpt.meta :保存了计算图的结构,可以理解为模型的结构
  • minst_model.ckpt.index 和 minst_model.ckpt.data-00000-of-00001:保存了模型中所有变量的值.

2.模型加载

   保存好模型之后,我们利用自己的图片对模型进行测试。我们利用windows自带的画图软件,进行数字手写,并保存成28*28像素的png图片。例如0,1,2手写体图片,如下图所示:


     

  整个测试代码如下:

from PIL import Image, ImageFilter
import tensorflow as tf

def imageprepare():
    file_name = 'pic/2-3.png'  # 图片路径
    myimage = Image.open(file_name).convert('L')  # 转换成灰度图
    tv = list(myimage.getdata())  # 获取像素值
    # 转换像素范围到[0 1], 0是纯白 1是纯黑
    tva = [(255-x)*1.0/255.0 for x in tv] 
    return tva

result = imageprepare()
init = tf.global_variables_initializer()
saver = tf.train.Saver

with tf.Session() as sess:
    sess.run(init)
    saver = tf.train.import_meta_graph('minst_model.ckpt.meta')  # 载入模型结构
    saver.restore(sess,  'minst_model.ckpt')  # 载入模型参数

    graph = tf.get_default_graph()  # 计算图
    x = graph.get_tensor_by_name("x:0")  # 从模型中获取张量x
    y = graph.get_tensor_by_name("y:0")  # 从模型中获取张量y

    prediction = tf.argmax(y, 1)
    predint = prediction.eval(feed_dict={x: [result]}, session=sess)
    print(predint[0])

  在加载模型时,我们先用tf.train.import_meta_graph()载入模型的结构,之后利用saver.restore()加载模型的训练好的参数。graph.get_tensor_by_name()依照名字(name)从模型中获取张量。所以前面在保存模型时我们给每个张量和变量都加了name关键字
  关于如何保存和加载训练模型可以参见博客【TensorFlow保存还原模型的正确方式】

3.识别结果

  输出的识别结果如下所示:


    

     

    

  经测试,该方法基本识别率可以达到90%左右。所以基本可以满足要求。

4.注意事项

  最早时,我手写数字进行识别时,发现准确率很低。
  后来发现原因是:(1)我自己手动画的数字线条太细了;(2)画的有些数字在图片中的位置没有位于中心;(3)训练集是西方的手写数字,和中国的手写数字习惯不同。下面是官方的训练数据中的部分数字。


  在画图时,数字效果(画笔粗细等)尽量和上面训练集保持一致,就会得到较高的识别率!
  是以为记!

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值