学习 Keras 为主,内容比较简单
1 背景与目标
1.1 背景
Keras 是一个由纯 Python 编写而成的深度学习框架,并基于 Tensorflow、Theano 及 CNTK 后端。Keras 方便易用,能够迅速将 idea 转换为结果。
本文利用 Keras 构建多层感知器,进行0-9 的『手写数字识别』 ,旨在感受 Keras 库的友好与易用。
2.2 目标
使用 Keras 搭建深度学习模型,对 0-9 的手写数字图片进行有监督学习与识别。
数字取自开源的 mnist 数据库,总共包含:60 000 张训练集图片
10 000 张测试集图片
训练集与测试集数据均包含对应的标签。60 000 张训练集图片
2 数据预处理
mnist 数据库中的手写数字图片库是学习 Deep Learning 的基本素材,本项目虽然有图片格式的数据,但为方便操作,后续直接采用 .npz 格式的数据包。下载链接见此处。
3 代码实现
3.1 导入依赖库Sequential:Keras 中的序贯模型,其特点是多个网络层线性堆叠
Dense:全连接神经元层
Dropout:神经元输入的断接率
Activation:神经元层的激励函数
SGD:优化器(optimizers)中的随机梯度下降法
3.2 导入数据集
a) 处理数据集每一个手写数字图片为 28*28 像素的、经过灰度处理的图片,每个像素用 0-255 的 RGB 值表示
x_train/y_train:训练集数据及其对应标签,共 60 000 组
t_test/y_test:测试集数据及其标签,共 10 000 组
reshape:将 (60000, 28, 28) 的 3 维数组转化为 (60000, 28*28) 2维数组,变换后每一行代表一张图片,共 28*28=784 个点,每个点范围为 [0, 255]
astype('float32'):由于神经网络计算是需要用到大量线性运算,于是将数据设置为 32 位浮点数归一化数据集,将范围为 [0, 255] 的数据归一为 [0, 1]
b) 处理标签集
标签集中为范围 [0, 9] 的手写数字真实值,需要将其转化为二进制矩阵。Y_train:训练集标签,大小为 60000*10。60000 表示图片数量,10 表示数字值,如 (0, 0, 0, 1, 0, 0, 0, 0, 0, 0) 中第四位为 1,表示图片真实值为 (0, 1, 2, 3, 4, 5, 6, 7, 8, 9)中的第四位,即手写数字为 3
Y_train:同上,共 10000 组图片
3.3 建立神经网络模型Sequential():选用序贯模型进行训练及识别
batch_size = 128:每次梯度下降运算时包含的数据数量为 128
epochs = 20:迭代 20 次
Dense(500):每层 500 个神经元,输入层、隐藏层、输出层的连接为 Dense 全连接方式。
input_shape=(28*28,):输入为 784 维向量。
Dense(10):输出 10 维向量
relu:隐藏层的激活函数
softmax:输出层的激活函数
Relu 激活函数,其更易于学习优化:
用 summary 进行总结,观察神经网络模型,该模型一共 898510 个参数:
3.4 编译模型sgd:随机梯度下降法
lr=0.01:学习率
decay=1e-6:每次更新后的学习率衰减值
momentum=0.9:动量参数
nesterov=True:确定是否使用Nesterov动量
optimizer = sgd:优化器选为随机梯度下降法
loss = 'categorical_crossentropy':损失函数选为多类的对数损失,适用于二值序列
metrics = ['accuracy']:指标列表,用于性能评估,一般设置为 metrics=['accuracy']
3.5 训练模型verbose = 1:显示训练日志
X_train, Y_train:训练数据集
validation_data = (X_test, Y_test):测试数据集
loss/acc:训练集的损失值与准确值
val_loss/val_acc:测试集的损失值与准确值
3.6 可视化结果
a) 将训练结果转化为 DataFrame
b) 可视化结果上:训练集的准确率与损失值
下:测试集的准确率与损失值
4 性能评价
从上表与上图中可以发现:训练集的准确率随迭代次数的增加而上升,在第 7 次后,准确率已经达到近 98%
经过 20 次迭代,训练集的准确率达到 99.37%,已经非常完美了
事实上上表已经包含测试集应用模型的结果,但按照国际惯例,我们还是要用 evaluate 方法对模型进行评估:
可以发现,经过训练后的模型:准确率为:98.25%
说明本项目建立的神经网络模型在手写数字的识别上,具有 98.25% 的识别准确率。
5 总结反思
to be continued……