前言
- 利用搭建网络八股,使用简单的bp神经网络完成手写数字的识别。
搭建过程
- 导入相应的包
- 获取数据集,划分数据集和测试集并进行简单处理(归一化等)
- 对数据进行乱序处理
- 定义网络结构
- 选择网络优化器以及损失函数
- 进行训练,设置batch大小以及迭代次数
代码实现
import tensorflow as tf
import numpy as np
# 导入手写数字数据集
from tensorflow.keras.datasets import mnist
# 获取数据
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# 对数据归一化
x_train = x_train / 255.0
x_test = x_test / 255.0
# 打乱训练集顺序
# 设置随机种子,使训练集和标签打乱顺序的次序一样
np.random.seed(116)
np.random.shuffle(x_train)
np.random.seed(116)
np.random.shuffle(y_train)
# 搭建网络
model = tf.keras.models.Sequential([
# 将照片拉直成1维数组
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(128, activation='relu', kernel_regularizer=tf.keras.regularizers.l2()),
# 10个数字,最后一层设置为10个神经元
tf.keras.layers.Dense(10, activation='softmax')]
)
# 设置训练参数
model.compile(optimizer=tf.keras.optimizers.Adam(lr=0.01), loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
metrics=['sparse_categorical_accuracy'])
# 训练模型
model.fit(x_train, y_train, batch_size=400, epochs=10, validation_data=(x_test, y_test), validation_freq=10)
# 打印网络结构
model.summary()
到这完成了手写数字识别体训练,有什么问题,大家可以留言。