一、断点续训
断点续训是指,如果之前训练过现在要做的模型,则可以通过读入之前训练过的模型及其参数,拿来训练现在要做的目标模型,这样可以节省训练时间,加快训练效率。
1.读取模型:
在之前已经保存过的模型中读取,被保存的模型名应该是ckpt为后缀的,index为后缀的是模型的索引:
load.weights(路径文件名)
代码为:
#判断如果目标文件存在,则直接读取模型
checkpoint_save_path = "./checkpoint/mnist.ckpt"
if os.path.exists(checkpoint_save_path + '.index'):
print('-------------load the model-----------------')
model.load_weights(checkpoint_save_path)
2.保存模型:
将模型保存到指定路径:
tf.keras.callbacks.ModelCheckpoint(
filepath=路径文件名,
save_weights_only=True/False,
save_best_only=True/False,
history=model.fit(callbacks=[cp_callback]))
代码为:
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,
save_weights_only=True,
save_best_only=True)
history = model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test)</