最近刚学深度学习模型,今天说一下一个经典的Keras实列:基于MNIST数据集的手写数字识别,也是深度学习过程中的一个入门级别的模型。
首先介绍一下minst数据集:MNIST数据集(Mixed National Institute of Standards and Technology database)是美国国家标准与技术研究院收集整理的大型手写数字数据库,包含60,000个示例的训练集以及10,000个示例的测试集。
大概图片如下:
下载数据集可以通过代码直接下载,也可以下载压缩包(导入解压过程比较麻烦),mnist数据集分为4类:训练数据集,训练数据集标签,测试数据集和测试数据集标签,
当我们执行代码的时候,会自己去下载mnist数据集(建议在早上下载),当导入数据集以后,我们开始初始化数据,然后进行模型的训练,最后用测试集进行测试,得出损失和准确率。
首先,便是各种包的导入:
然后,便是Mnist数据集的导入,使用mnist.load_data()可以自己去下载,接下来便初始化数据。
对于数据初始化以后,便设计自己的模型,我设置了2层卷积,2层池化,为了过拟合,设置Dropout,最后一层全连接输出。
然后将数据导入模型,进行训练
最后。用测试集进行测试,得到准确度。
下面附上所有代码链接: https://github.com/litongtong10067/CNN