1、Data preparation
import numpy as np
import tensorflow as tf

# Load the MNIST dataset
data = np.load('/Users/code/MNIST_data/mnist.npz', allow_pickle=True)
x_train, y_train = data['x_train'], data['y_train']
x_test, y_test = data['x_test'], data['y_test']
# Normalize pixel values into the 0~1 range
x_train = x_train.astype('float32') / 255.
x_test = x_test.astype('float32') / 255.
# Shuffle the training set (features and labels with the same permutation)
perm = np.random.permutation(x_train.shape[0])
x_train = x_train[perm]
y_train = y_train[perm]
# One-hot encode the labels
y_train = tf.keras.utils.to_categorical(y_train, 10)
y_test = tf.keras.utils.to_categorical(y_test, 10)
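To make the two preprocessing steps above concrete, here is a minimal pure-Python sketch (no NumPy/TensorFlow) of what they do: shuffling features and labels with one shared permutation so pairs stay aligned, and one-hot encoding class indices the way tf.keras.utils.to_categorical does. The function names are illustrative, not part of any library.

```python
import random

def shuffle_together(features, labels, seed=0):
    # One shared index permutation keeps each feature paired
    # with its label, like indexing x_train[perm], y_train[perm].
    idx = list(range(len(features)))
    random.Random(seed).shuffle(idx)
    return [features[i] for i in idx], [labels[i] for i in idx]

def to_one_hot(labels, num_classes):
    # Pure-Python equivalent of tf.keras.utils.to_categorical:
    # class index i becomes a length-num_classes vector with a 1.0 at i.
    return [[1.0 if j == label else 0.0 for j in range(num_classes)]
            for label in labels]

print(to_one_hot([3], 5))  # [[0.0, 0.0, 0.0, 1.0, 0.0]]
```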
2、Load the model with keras.models.load_model
from tensorflow.keras import models

model_path = '/Users/code/model/mnist.h5'
new_model = models.load_model(model_path)
3、Compile the model
# optimizer: the optimizer
# loss: the loss function
# metrics: the metrics to track (here, accuracy)
new_model.compile(optimizer='adam',
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])
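The 'categorical_crossentropy' loss chosen here is why the labels had to be one-hot encoded in step 1. A minimal pure-Python sketch of the per-sample loss, -sum(t * log(p)) over the classes (the function name is illustrative):

```python
import math

def categorical_crossentropy(y_true, y_pred, eps=1e-7):
    # Per-sample categorical cross-entropy: -sum(t * log(p)).
    # With a one-hot y_true, only the true class's log-probability counts.
    # eps guards against log(0), as numerical implementations do.
    return -sum(t * math.log(max(p, eps)) for t, p in zip(y_true, y_pred))

loss = categorical_crossentropy([0, 0, 1], [0.1, 0.2, 0.7])
print(round(loss, 4))  # -log(0.7) ≈ 0.3567
```

Had the labels been left as integer class indices, 'sparse_categorical_crossentropy' would be the matching loss instead.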
4、Train the model
# x_train, y_train: features and labels of the training set
# validation_data: features and labels of the validation set (here, the test set)
# epochs: how many passes to make over the dataset
# batch_size: number of samples per batch
# verbose: training log verbosity; verbose=1 shows a progress bar, verbose=2 prints per-epoch training and validation loss
# callbacks=[callback]: callback functions, optional
new_model.fit(x_train, y_train,
              validation_data=(x_test, y_test),
              epochs=5,
              batch_size=128,
              verbose=1)
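The epochs and batch_size arguments determine how many gradient updates run per epoch: Keras rounds the last partial batch up. A quick sketch of that arithmetic, assuming MNIST's 60,000 training samples:

```python
import math

def steps_per_epoch(num_samples, batch_size):
    # One step per batch; the final partial batch still counts as a step,
    # so the count rounds up.
    return math.ceil(num_samples / batch_size)

print(steps_per_epoch(60000, 128))  # 469
```

This matches the "469/469" shown in the progress bars of the training log below.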
5、Print the model summary
new_model.summary()
Output:
Epoch 1/5
469/469 [==============================] - 52s 109ms/step - loss: 0.0211 - accuracy: 0.9930 - val_loss: 0.0271 - val_accuracy: 0.9913
Epoch 2/5
469/469 [==============================] - 49s 105ms/step - loss: 0.0145 - accuracy: 0.9954 - val_loss: 0.0274 - val_accuracy: 0.9910
Epoch 3/5
469/469 [==============================] - 52s 110ms/step - loss: 0.0111 - accuracy: 0.9964 - val_loss: 0.0268 - val_accuracy: 0.9917
Epoch 4/5
469/469 [==============================] - 53s 112ms/step - loss: 0.0089 - accuracy: 0.9970 - val_loss: 0.0276 - val_accuracy: 0.9924
Epoch 5/5
469/469 [==============================] - 52s 111ms/step - loss: 0.0080 - accuracy: 0.9973 - val_loss: 0.0225 - val_accuracy: 0.9940
Model: "sequential"
_________________________________________________________________
 Layer (type)                    Output Shape              Param #
=================================================================
 conv2d (Conv2D)                 (None, 28, 28, 32)        320
 max_pooling2d (MaxPooling2D)    (None, 14, 14, 32)        0
 conv2d_1 (Conv2D)               (None, 12, 12, 64)        18496
 max_pooling2d_1 (MaxPooling2D)  (None, 6, 6, 64)          0
 conv2d_2 (Conv2D)               (None, 4, 4, 128)         73856
 flatten (Flatten)               (None, 2048)              0
 dense (Dense)                   (None, 128)               262272
 dense_1 (Dense)                 (None, 10)                1290
=================================================================
Total params: 356,234
Trainable params: 356,234
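The Param # column can be checked by hand: a Conv2D layer has (kernel_height * kernel_width * input_channels + 1 bias) weights per output channel, and a Dense layer has (input_units + 1 bias) weights per output unit. Assuming 3x3 kernels (consistent with the counts and shapes in the summary above):

```python
def conv2d_params(kernel_h, kernel_w, in_ch, out_ch):
    # (kernel area * input channels + 1 bias) per output channel
    return (kernel_h * kernel_w * in_ch + 1) * out_ch

def dense_params(in_units, out_units):
    # (input units + 1 bias) per output unit
    return (in_units + 1) * out_units

layers = [
    conv2d_params(3, 3, 1, 32),    # conv2d: 320
    conv2d_params(3, 3, 32, 64),   # conv2d_1: 18496
    conv2d_params(3, 3, 64, 128),  # conv2d_2: 73856
    dense_params(2048, 128),       # dense: 262272
    dense_params(128, 10),         # dense_1: 1290
]
print(sum(layers))  # 356234
```

The pooling and flatten layers contribute zero parameters, so the five entries above account for the full total of 356,234.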