机器学习入门都是从MNIST开始,Tensorflow官方社区提供了十分详细的教程【MNIST机器学习入门】。但是我们显然不满足于仅仅把官方的代码复制一遍然后输出个结果,我们想能不能实现自己手写数字的识别。
本文作为Tensorflow入门,结合官方代码,利用Softmax回归函数,实现模型的训练、保存、以及重新加载,完成对自己手写数字的识别。
1.模型训练及保存
模型我们采用Softmax回归函数,具体代码参考【MNIST机器学习入门】,这里用梯度下降算法以0.01学习率最小化交叉熵对模型进行1000次训练。
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True) # 插入数据
# name在保存模型时非常有用
x = tf.placeholder("float", [None, 784], name='x')
W = tf.Variable(tf.zeros([784, 10]), name='W')
b = tf.Variable(tf.zeros([10]), name='b')
y = tf.nn.softmax(tf.matmul(x, W) + b, name='y') # y预测概率分布
y_ = tf.placeholder("float", [None, 10]) # y_实际概率分布
cross_entropy = -tf.reduce_sum(y_ * tf.log(y)) # 交叉熵
# 梯度下降算法以0.01学习率最小化交叉熵
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
init = tf.initialize_all_variables() # 初始化变量
sess = tf.Session()
sess.run(init)
saver = tf.train.Saver()
for i in range(1000): # 开始训练模型,循环1000次
batch_xs, batch_ys = mnist.train.next_batch(100)
sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
saver.save(sess, 'minst_model.ckpt') # 保存模型
在代码最前面,定义张量(变量)时,我们给每个张量(变量)都加了name关键字,这个对我们后期再次加载模型很重要。
在最后,我们利用saver.save()函数,保存模型。模型名称为minst_model.ckpt。之后我们可以在文件夹下看到4个文件:
- checkpoint: 保存目录下所有模型文件列表
- minst_model.ckpt.meta :保存了计算图的结构,可以理解为模型的结构
- minst_model.ckpt.index 和 minst_model.ckpt.data-00000-of-00001:保存了模型中所有变量的值.
2.模型加载
保存好模型之后,我们利用自己的图片对模型进行测试。我们利用windows自带的画图软件,进行数字手写,并保存成28*28像素的png图片。例如0,1,2手写体图片,如下图所示:
整个测试代码如下:
from PIL import Image, ImageFilter
import tensorflow as tf
def imageprepare():
file_name = 'pic/2-3.png' # 图片路径
myimage = Image.open(file_name).convert('L') # 转换成灰度图
tv = list(myimage.getdata()) # 获取像素值
# 转换像素范围到[0 1], 0是纯白 1是纯黑
tva = [(255-x)*1.0/255.0 for x in tv]
return tva
result = imageprepare()
init = tf.global_variables_initializer()
saver = tf.train.Saver
with tf.Session() as sess:
sess.run(init)
saver = tf.train.import_meta_graph('minst_model.ckpt.meta') # 载入模型结构
saver.restore(sess, 'minst_model.ckpt') # 载入模型参数
graph = tf.get_default_graph() # 计算图
x = graph.get_tensor_by_name("x:0") # 从模型中获取张量x
y = graph.get_tensor_by_name("y:0") # 从模型中获取张量y
prediction = tf.argmax(y, 1)
predint = prediction.eval(feed_dict={x: [result]}, session=sess)
print(predint[0])
在加载模型时,我们先用tf.train.import_meta_graph()载入模型的结构,之后利用saver.restore()加载模型的训练好的参数。graph.get_tensor_by_name()依照名字(name)从模型中获取张量。所以前面在保存模型时我们给每个张量和变量都加了name关键字。
关于如何保存和加载训练模型可以参见博客【TensorFlow保存还原模型的正确方式】
3.识别结果
输出的识别结果如下所示:
经测试,该方法基本识别率可以达到90%左右。所以基本可以满足要求。
4.注意事项
最早时,我手写数字进行识别时,发现准确率很低。
后来发现原因是:(1)我自己手动画的数字线条太细了;(2)画的有些数字在图片中的位置没有位于中心;(3)训练集是西方的手写数字,和中国的手写数字习惯不同。下面是官方的训练数据中的部分数字。
在画图时,数字效果(画笔粗细等)尽量和上面训练集保持一致,就会得到较高的识别率!
是以为记!