本文涉及到的是中国大学慕课《人工智能实践:Tensorflow笔记》第四讲第四节的内容,通过tensorflow实现神经网络模型的断点续训。
断点续训
神经网络模型的断点续训指的是将训练好的模型保存下来,并在之后的运行中直接调用的训练方法。
保存模型使用以下代码,将模型参数保存到callbacks中,并将callbacks添加到模型训练的history中。
callbacks = tf.keras.callbacks.ModelCheckpoint( (
filepath= = 路径文件名,
save_weights_only= = True/False,
save_best_only= = True/False) )
history = model.fit( ( callbacks=[cp_callback] )
读取模型使用以下代码,基本原理是生成.ckpt文件的同时会生成对应的索引表(.index文件),通过判断索引表是否存在来决定是否导入保存的模型。
checkpoint_save_path = "./checkpoint/mnist.ckpt" # 给出模型保存的路径以及文件名
if os.path.exists(checkpoint_save_path + '.index'): # 通过索引表判断保存的模型是否存在
print('-------------load the model-----------------') # 是,则打印"导入模型"
model.load_weights(checkpoint_save_path) # 导入模型
代码实现
从神经网络搭建的六步法来看,与DL with python(6)——Keras实现手写数字识别(全连接网络)中直接导入mnist数据的代码相比,断点续训的代码在第一步、第四步和第五步有所改动。
# 第一步,导入相关模块,os模块用于判断文件是否存在
import tensorflow as tf
import os
# 第二步,导入数据集
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
# 第三步,搭建网络结构
model = tf.keras.models.Sequential([
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dense(10, activation='softmax')
])
# 第四步,配置训练方法
model.compile(optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
metrics=['sparse_categorical_accuracy'])
# 导入保存的模型,第二次运行才可以进行的操作
checkpoint_save_path = "./checkpoint/mnist.ckpt" # 给出模型保存的路径以及文件名
if os.path.exists(checkpoint_save_path + '.index'): # 通过索引表判断保存文件是否存在
print('-------------load the model-----------------') # 是,则打印"导入模型"
model.load_weights(checkpoint_save_path) # 导入模型
# 保存模型,第一次运行执行这一步操作
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path, # 模型保存路径
save_weights_only=True, # 只保留模型参数
save_best_only=True) # 只保留最优结果
# 第五步,执行训练,依次为训练集样本,训练集标签,小批量大小32,训练轮次5,测试集,训练集循环1轮次进行一次测试
history = model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1,
callbacks=[cp_callback]) # 最后添加callbacks回调选项,将前面保存的模型参数赋给后面的模型
# 第六步,打印网络结构和参数统计
model.summary()
第一次运行代码后,会得到一个checkpoint文件夹,其中含有四个文件,含有模型的相关信息。
第一次运行,模型的最终表现
60000/60000 [==============================] - 5s 77us/sample - loss: 0.0454 - sparse_categorical_accuracy: 0.9863 - val_loss: 0.0865 - val_sparse_categorical_accuracy: 0.9752
然后第二次运行,在第一次的基础上进行训练,最终表现如下,可以看到各方面指标都有了很大的进步。
60000/60000 [==============================] - 4s 69us/sample - loss: 0.0180 - sparse_categorical_accuracy: 0.9943 - val_loss: 0.0785 - val_sparse_categorical_accuracy: 0.9782