基于一个线性层的softmax回归模型和MNIST数据集识别自己手写数字

原博文是用cnn识别,因为我是在自己电脑上跑代码,用不了处理器,所以参考Mnist官网上的一个线性层的softmax回归模型的代码,把两篇文章结合起来识别。

最后效果

在这里插入图片描述在这里插入图片描述
在这里插入图片描述在这里插入图片描述
在这里插入图片描述在这里插入图片描述
在这里插入图片描述在这里插入图片描述
源代码识别mnist数据集的准确率是91.5%,我改了之后,识别自己手写数字的话,4,8,9比较难识别,其他数字基本可以一次成功。无法识别的原因,一个应该是算法本身精度不够,一个是数据集是欧美手写体,和东方手写体有差异,这个可以看我上面写的5比较夸张。。。

先附上

原博文

这个是用cnn识别的:https://blog.csdn.net/qq_38269418/article/details/78991649
mnist官网代码
http://www.tensorfly.cn/tfdoc/tutorials/mnist_pros.html
我是根据这两篇文章进行复现的。

第一步训练和保存模型

先对模型进行训练和保存,先新建文件夹,我和原博一样,文件夹命名是SAVE。

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

这上面几行代码调用mnist数据集,用这几行代码就不需要去官网下载了

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

import tensorflow as tf
sess = tf.InteractiveSession()

x = tf.placeholder("float", shape=[None, 784])
y_ = tf.placeholder("float", shape=[None, 10])

W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))

sess.run(tf.initialize_all_variables())
y = tf.nn.softmax(tf.matmul(x,W) + b)
cross_entropy = -tf.reduce_sum(y_*tf.log(y))

saver = tf.train.Saver() #定义saver

train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
for i in range(1000):
  batch = mnist.train.next_batch(50)
  train_step.run(feed_dict={x: batch[0], y_: batch[1]})
  saver.save(sess, 'D:\PycharmProject\mnist_model\SAVE\model.ckpt') #模型储存位置

  correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
  accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
  print (accuracy.eval(feed_dict={x: mnist.test.images, y_: mnist.test.labels}))

第一步成功的话,可以看到这个算法精度是91.5%
在这里插入图片描述
在新建的SAVE文件夹也可以看到这几个文件
在这里插入图片描述
可能没有那么多,因为我跑了好几遍代码,改了几次,文件有多出来。。。
如果可以看到精度和文件就说明第一步成功。

第二步 准备手写数字

这一步我和原博一样是用photoshop做的,28*28像素,RGB,3像素画笔。按照原文来就可以。最后以png格式和项目保存到同一个文件夹。

第三步 识别图片

前两步和原博差不多,第三步有一些要改动。

from PIL import Image, ImageFilter
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

如果一开始没装过PIL,需要下载一下,cmd输入pip install pillow,或者直接在pycharm里面下载。

import tensorflow as tf
sess = tf.InteractiveSession()
def imageprepare():
    im = Image.open('D:/PycharmProject/mnist_model/7.png') #读取的图片所在路径,注意是28*28像素
    plt.imshow(im)  #显示需要识别的图片
    plt.show()
    im = im.convert('L')
    tv = list(im.getdata())
    tva = [(255-x)*1.0/255.0 for x in tv]
    return tva

result=imageprepare()
x = tf.placeholder(tf.float32, [None, 784])
y_ = tf.placeholder(tf.float32, [None, 10])

到这边都和原博差不多,调用不同图片,就改一下文件名就可以。下面的算法就不一样了,采用的是一个线性层的softmax回归模型,自己电脑也能跑,挺快的。

W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))

sess.run(tf.initialize_all_variables())
y = tf.nn.softmax(tf.matmul(x,W) + b)
cross_entropy = -tf.reduce_sum(y_*tf.log(y))

train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
for i in range(1000):
  batch = mnist.train.next_batch(50)
  train_step.run(feed_dict={x: batch[0], y_: batch[1]})


correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))

saver = tf.train.Saver()  # 定义saver



with tf.Session() as sess:
      sess.run(tf.global_variables_initializer())
      saver.restore(sess, 'D:\PycharmProject\mnist_model\SAVE\model.ckpt')

这个参考的是mnist官网代码,把模型的保存地址改一下就行。

  prediction = tf.argmax(y, 1)
      predint = prediction.eval(feed_dict={x: [result]}, session=sess)

这两行代码是需要自己改的,原博是y_conv,我没有用神经网络算法,改成y,一开始我改的是y_结果第二行代码一直报错,还找不到为什么错,哭唧唧,找了大神师兄帮我看代码才找到问题,师兄永远是师兄!
错误的原因是:y_是占位符,需要给他提供实际的标签,所以y_是实际的,y才是神经网络的输出,要把计算出来的y和实际的y_去比较,所以第一句才是用y,因为是预测的。
第二行,原博有三个占位符,{}要输两个,我只有两个占位符,所以输入图片结果就可以。

    print('识别结果:')
      print(predint[0])

最后输出识别结果。
在这里插入图片描述在这里插入图片描述

至此,大功告成!

谨以此文纪念一下小菜鸟第一个跌跌撞撞终于成功的并不太复杂的程序!

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值