Keras — mnist 手写体数字识别
一、前序知识
-
MNIST:大多数示例使用手写数字的MNIST数据集。
该数据集包含60,000个用于训练的示例和10,000个用于测试的示例。这些数字已经过尺寸标准化并位于图像中心。
MNIST图像数据集使用 [28,28] 的二维数组来表示每个手写体数字,数组中的每一个元素对应于一个像素点,即每张图片固定为 28 * 28 像素大小。
数据集官网:http://yann.lecun.com/exdb/mnist/
-
Kears
关于Keras:
Keras是由纯python编写的基于theano/tensorflow的深度学习框架。
Keras是一个高层神经网络API,支持快速实验,能够把你的idea迅速转换为结果,如果有如下需求,可以优先选择Keras:- 简易和快速的原型设计(keras具有高度模块化,极简,和可扩充特性)
- 支持CNN和RNN,或二者的结合
- 无缝CPU和GPU切换
二、Keras 训练 mnist
- 使用 Keras 加载 MNIST 数据集:
# 导入手写体的数据集
from keras.datasets import mnist
dataset_name = 'mnist.npz'
data = mnist.load_data(dataset_name)
(x_train, y_train), (x_test, y_test) = data
# reshape(samples, feature)
x_train = x_train.reshape(60000, 784)
x_test = x_test.reshape(10000, 784)
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
- 数据预处理:
# 导入one-hot,标签转化为向量
from keras.utils import to_categorical
# 将图片灰度值归一化到0~1之间的值
x_train /= 255
x_test /= 255
# 将标签数据进行 one-hot编码
# 分类类别
num_category = 10
# one-hot, 将类别标签转化为二进制的向量,第i位为1,就表示为第i类
y_train = to_categorical(y_train, num_category)
y_test = to_categorical(y_test, num_category)
- 使用 Keras 定义 MNIST 网络模型:
input_shape=(784,):输入层 784 为样本的特征数28 * 28
输出层神经元个数为10,为分类的类别数
activation=‘softmax’:多分类问题,输出层激活函数使用 softmax
# 导入相关层的结构
from keras.models import Sequential
from keras.layers import Dense
# Sequential 顺序模型,它由多个网络层线性堆叠
model = Sequential()
# 使用 .add() 来堆叠模型:
model.add(Dense(512, activation='relu', input_shape=(784,), name='Dense_0'))
model.add(Dense(512, activation='relu', name='Dense_1'))
model.add(Dense(10, activation='softmax', name='Dense_2'))
- 可视化网络:
# 可视化神经网络
from keras.utils import plot_model
print(model.summary())
plot_model(model, to_file='mnist.png')
模型信息:
- model.summary():
Model: "sequential_1"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
Dense_0 (Dense) (None, 512) 401920
_________________________________________________________________
Dense_1 (Dense) (None, 512) 262656
_________________________________________________________________
Dense_2 (Dense) (None, 10) 5130
=================================================================
Total params: 669,706
Trainable params: 669,706
Non-trainable params: 0
_________________________________________________________________
- mnist.png:
- 使用 Keras 训练 MNIST 网络模型:
metrics=[‘accuracy’]:Keras自带的统计准确率
# 导入keras的优化器
from keras.optimizers import RMSprop
# 使用 .compile() 来配置学习过程:
model.compile(loss='categorical_crossentropy',
optimizer=RMSprop(),
metrics=['accuracy'])
# 进行训练
history = model.fit(x_train, y_train,
batch_size=128, epochs=10,
verbose=1, validation_data=(x_test, y_test))
# 评估模型性能
score = model.evaluate(x_test, y_test, verbose=0)
print(score)
- 保存训练好的 MNIST 模型:
# 保存训练好的模型
model.save('mnist.h5')
三、完整代码
完整代码已上传至github:
https://github.com/pentiumCM/machinelearn/blob/master/machinelearn/keras_learn/hello_mnist.py