1. 什么是断点续训问题
你可以这样想,当你训练模型需要很多epoch,但是你具体有多少个epoch才能达到你的标准,你把epoch设置高了不仅会加大资源的消耗而且很容易出现其他的一些问题。所以有人就想了,能不能把模型分多次跑,每一次在上一次的基础上继续训练,直到达到我们满意的效果,如果可以这样就存在一个问题,就是如何保存每一次的训练参数,这就是断点续训问题。
2. 如何执行断点续训
2.1 tf.keras.callbacks.ModelCheckpoint()类
我们可以使用tensorflow提供的类tf.keras.callbacks.ModelCheckpoint()
tf.keras.callbacks.ModelCheckpoint(
filepath,
monitor='val_loss',
verbose=0,
save_best_only=False,
save_weights_only=False,
mode='auto',
save_freq='epoch',
**kwargs,
)
Docstring:
Save the model after every epoch. # 每一个epoch之后保存参数,可以使用参数save_best_only=True来只保存最好的参数
`filepath` can contain named formatting options,
which will be filled the value of `epoch` and
keys in `logs` (passed in `on_epoch_end`).
For example: if `filepath` is `weights.{epoch:02d}-{val_loss:.2f}.hdf5`,
then the model checkpoints will be saved with the epoch number and
the validation loss in the filename.
经常用的几个参数:
filepath:保存模型文件的路径
save_best_only:只保存最好模型时的参数(True)
save_weights_only:如果为True,只保存模型参数信息(model.save_weights(filepath)); 如果为False:则保存整个模型(model.save(filepath)
save_freq:'epoch'或者是一个integer.当使用'epoch'时,模型会在每个epoch之后保存代码,当使用integer时,模型会在你指定的几次epoch之后开始保存模型。如果你使用了save_best_only,则这个参数不需要设置
2.2 开始断点续训
checkpoint_save_path = "./checkpoint/mnist.ckpt" # 声明一个ckpt文件存储路径
if os.path.exists(checkpoint_save_path + '.index'): # 这里判断是否之前已经存储了模型的训练信息,如果是,则为模型加载之前的参数
print('-------------load the model-----------------')
model.load_weights(checkpoint_save_path)
# save model after every epoch
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,
save_weights_only=True,
save_best_only=True)
history = model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1,
callbacks=[cp_callback])
model.summary()
第一次运行代码
第二次运行代码
观察准确率可以看到,第二次运行代码是在第一次的基础上运行的