大家好我是《高手杰瑞》,每天更新精彩教程,请关注我哦!
mnist数据集介绍
mnist数据是一个已经被用“烂”的一个数据集了,因为每一个深度学习入门教程中都会讲到mnist数据集,就连谷歌的tensorflow框架都内置mnist数据集的相关操作,以供初学者进行学习。
mnist数据集是一个手写数字数据集,由美国国家标准与技术研究所进行制作,mnist数据集里面分为训练集和测试集,训练集中的数字是由250个不同的人手写而成的,其中50%是高中生、50%来自人口普查局的工作人员,测试集也是以同样的比例来进行制作的。
CNN是什么?为什么用它?
CNN中文名称为卷积神经网络,它在图像识别任务中得到了特别广泛地使用。在结构上,我们会让一张输入的图片依次经过一系列的卷积层、非线性层、池化层最后与全连接层连接计算输出。简单点来讲,CNN能使计算机从图片中获得更多的特征信息。
为什么用它?
因为它能让我们识别的准确率提高,杰瑞层做过对比,仅使用一层全连接层的神经网络识别手写数字的准确率最高在94%左右,尽管数值看上去不低,但是在图像识别领域94%的准确率并不是很高,而在CNN模型中手写数字识别准确率最高达到了98%!尽管只有4%的差距,但在现实应用场景中,这4%已经是一个很大的提升了。
如何实现?
杰瑞接下来教大家如何去实现用卷积神经网络去识别手写数字,首先我们先把mnist数据集下载好,这里杰瑞教大家用tensorflow里面的方法来下载mnist数据集,先导入下面这个模块:
import tensorflow as tffrom tensorflow.examples.tutorials.mnist import input_data
然后使用read_data_sets()方法下载mnist数据集,这个方法第一个参数是我们存放数据集的地址,第二个是数据集的标签是否使用one-hot的形式,一般我们都是使用one-hot形式的标签,所以第二个参数填True.
mnist = input_data.read_data_sets("mnist/