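Here we use the first 1000 images of the MNIST dataset that ships with TensorFlow to train and test a model.
1. Preparation
1.1 Getting the dataset
The code below downloads the dataset. We keep only the first 1000 samples of each split and divide the pixel values by 255 to normalize them.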
from __future__ import absolute_import, division, print_function
import os
import tensorflow as tf
from tensorflow import keras

# Download MNIST (cached locally after the first call)
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()

# Keep only the first 1000 labels of each split
train_labels = train_labels[:1000]
test_labels = test_labels[:1000]
print(train_images.shape)
print(test_images.shape)

# Keep the first 1000 images, flatten 28x28 to 784 and scale pixels to [0, 1]
train_images = train_images[:1000].reshape(-1, 28 * 28) / 255.0
test_images = test_images[:1000].reshape(-1, 28 * 28) / 255.0
print(train_images.shape)
print(test_images.shape)
Output:
11493376/11490434 [==============================] - 14s 1us/step
(60000, 28, 28)
(10000, 28, 28)
(1000, 784)
(1000, 784)
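To make the slicing, reshaping and scaling step easier to follow without downloading anything, here is a minimal sketch on stand-in data. The random array `fake_images` and the helper name `flatten_and_scale` are assumptions introduced for illustration only; they are not part of the original code, and NumPy is used directly because `load_data()` returns NumPy arrays.

```python
import numpy as np

# Stand-in for MNIST: 5 random 28x28 uint8 "images" (hypothetical data, not the real dataset)
rng = np.random.default_rng(0)
fake_images = rng.integers(0, 256, size=(5, 28, 28), dtype=np.uint8)

def flatten_and_scale(images, n):
    """Keep the first n images, flatten each one to 784 values, scale to [0, 1]."""
    return images[:n].reshape(-1, 28 * 28) / 255.0

subset = flatten_and_scale(fake_images, 3)
print(subset.shape)   # one row of 784 features per kept image
print(subset.dtype)   # dividing by 255.0 turns the integers into floats
```

The `-1` in `reshape` lets NumPy infer the number of rows, so the same expression works for any number of kept images.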
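1.2 Defining a model
The model we train here has only three layers. The last layer uses softmax to output, for each sample, a predicted probability for every class.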
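As a rough illustration of what a softmax output layer computes, the sketch below turns a list of raw scores into probabilities using only the standard library. The function `softmax` and the example scores are assumptions made for this illustration; Keras applies its own implementation inside the layer.

```python
import math

def softmax(logits):
    """Map raw scores to positive values that sum to 1."""
    m = max(logits)                           # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

probs = softmax([2.0, 1.0, 0.1])
print(probs)        # the largest score gets the largest probability
print(sum(probs))   # approximately 1.0
```

With that in mind, the model used in this post is defined as follows.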
def create_model():
    model = tf.keras.models.Sequential([
        keras.layers.Dense(512, activation=tf.nn.relu, input_shape=(784,)),
        keras.layers.Dropout(0.2),
        keras.layers.Dense(10, activation=tf.nn.softmax)
    ])
    model.compile(optimizer=tf.keras.optimizers.Adam(),
                  loss=tf.keras.losses.sparse_categorical_crossentropy,
                  metrics=['accuracy'])
    return model

model = create_model()
model.summary()
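Output: the first layer is fully connected, so it has 512*(784+1)=401920 parameters (784 weights plus one bias per unit); the second layer (Dropout) has 0 parameters, ...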
Layer (type)                 Output Shape              Param #
=================================================================
dense_1 (Dense)              (None, 512)               401920
_________________________________________________________________
dropout_1 (Dropout)          (None, 512)               0
_________________________________________________________________
dense_2 (Dense)              (None, 10)                5130
=================================================================
Total params: 407,050
Trainable params: 407,050
Non-trainable params: 0
_________________________________________________________________
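The parameter counts in the summary can be checked by hand. The sketch below applies the same units*(inputs+1) rule to both Dense layers; the helper name `dense_params` is an assumption introduced here for illustration and is not a Keras function.

```python
def dense_params(n_inputs, n_units):
    """One weight per (input, unit) pair plus one bias per unit."""
    return n_units * (n_inputs + 1)

hidden = dense_params(784, 512)   # first Dense layer
dropout = 0                       # Dropout has no trainable parameters
output = dense_params(512, 10)    # final Dense layer
total = hidden + dropout + output

print(hidden, dropout, output, total)  # 401920 0 5130 407050
```

These values match the Param # column and the Total params line of the summary.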
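2. Saving checkpoints during training
Use the tf.keras.callbacks.ModelCheckpoint callback to implement checkpointing. It is configured through the parameters shown below.
2.1 Using the checkpoint callback
Train the model, passing the ModelCheckpoint callback to fit.
When I first ran the code as-is, it failed with: ImportError: `save_weights` requires h5py. I therefore tried pip install h5py, which told me to install cython first, so I ran pip install cython and then installed h5py. After that, import h5py succeeded.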
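To catch this kind of missing dependency before starting a training run, a small check like the following can help. It is only a sketch: the helper name `has_module` is an assumption introduced here, and it relies solely on the standard library's `importlib.util.find_spec`.

```python
import importlib.util

def has_module(name):
    """Return True if a top-level module with this name can be found."""
    return importlib.util.find_spec(name) is not None

if has_module("h5py"):
    print("h5py is available")
else:
    print("h5py is missing; try: pip install cython h5py")
```

The checkpoint callback and the training call then look like this: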
checkpoint_path = './cp.ckpt'
checkpoint_dir = os.path.dirname(checkpoint_path)

# Save only the weights, and log a message each time a checkpoint is written
cp_callback = tf.keras.callbacks.ModelCheckpoint(checkpoint_path,
                                                 save_weights_only=True,
                                                 verbose=1)

model = create_model()
model.fit(train_images, train_labels, epochs=10,
          validation_data=(test_images, test_labels),
          callbacks=[cp_callback])
Output:
Epoch 8/10
32/1000 [..............................] - ETA: 0s - loss: 0.1069 - acc: 1.0000
160/1000 [===>..........................] - ETA: 0s - loss: 0.0643 - acc: 1.0000
288/1000 [=======>......................] - ETA: 0s - loss: 0.0758 - acc: 0.9861
384/1000 [==========>...................] - ETA: 0s - loss: 0.0713 - acc: 0.9896
512/1000 [==============>...............] - ETA: 0s - loss: 0.0659 - acc: 0.9922
640/1000 [===&