本文我们利用python语言,通过tensorflow框架对手写字体MNIST数据库进行识别。
学习每一门语言都有一个“Hello World”程序,而对数字手写体数据库MNIST的识别就是深度学习的“Hello World”代码。下面我们给出详细的步骤。
tensorflow概述
tensorflow是用C++语言实现的一个深度学习模块。tensorflow是一种数据流编程,所谓的数据流编程,就是python编译环境只负责生成一个深度学习的数据流,然后将数据流传入C++语言的运行环境,在C++语言中执行,执行完的结果再返回到python运行环境,所以你会发现tensorflow的语言很复杂(深度学习本身就是一门较为前沿的领域),但是tensorflow的运行效率比较高,所以比较适合运行深度学习这样计算量庞大的算法。
下面我们来阐述一下手写字体的识别过程,其中手写字体的数据库下载在网上很容易找到:
MNIST数据集
MNIST数据集的官网是Yann LeCun's website。在这里,我们提供了一份python源代码用于自动下载和安装这个数据集。你可以用下面的代码导入到你的项目里面,也可以直接复制粘贴到你的代码文件里面。
下面我们给出具体的、完整的代码分析。
导入所需要的包
第一步,导入所需要的包:
其中,input_data这个类是tensorflow为数据集MNIST专门设计的,只针对MNIST数据集。
函数tf.reset_default_graph()是重置数据流。
导入数据
正如前文所介绍的,其中MNIST_data是文件夹的名字,文件夹需要与代码在同一个目录下,文件夹中的文件就是数据集,如图所示:
one_hot=True表示对标签进行独热编码。
构建网络
构建网络的输入输出,以及权重,如图所示:
其中x_data为输入的图像数据,y_data为输入的标签数据,w为网络的权重,bias为偏置,cross_e为基于交叉熵的损失函数,opt为梯度优化器,train为最终的训练接口。
训练
训练的代码也很简单,其中tf.train.Saver()是保存模型对象,后期可以继续调用。
测试
最后一步就是测试,从测试结果看,为0.91,其中x_test_data为测试的图像输入,y_test_data为测试图像的标签。
总结
本文用的虽然是tensorflow框架,处理的也是图片数据,但是并没有用到深度学习算法,只是用了普通的神经网络——全连接神经网络。所以效果为0.91,读者感兴趣可以使用CNN卷积神经网络对数据进行训练,可以得到较高的识别率。
谢谢阅读,希望对你的学习有所帮助。