tf.keras.callbacks.ModelCheckpoint(
filepath, monitor='val_loss', verbose=0, save_best_only=False,
save_weights_only=False, mode='auto', save_freq='epoch',
options=None, **kwargs
)
参数:
- filepath:字符串,保存模型的路径。
- monitor:需要监视的评估指标名。
- verbose:详细信息模式,0或1。
- save_best_only:如果save_best_only=True,被监测数据的最佳模型就不会被覆盖。
- mode: {auto, min, max} 的其中之一。如果 save_best_only=True,那么是否覆盖保存文件的决定就取决于被监测数据的最大或者最小值。 对于 val_acc,模式就会是 max,而对于 val_loss,模式就需要是 min,等等。在auto模式下,方向由被监测值的名字自动推断。
- save_weights_only:如果为True,则仅保存模型的权重(model.save_weights(filepath)),否则保存完整模型(model.save(filepath))。
- period:CheckPoint之间的间隔(epoch数)。
filepath 可以包括命名格式选项,可以由 epoch的值和 logs的键(由 on_epoch_end 参数传递)来填充。
例如:如果 filepath 是 weights.{epoch:02d}-{val_loss:.2f}.hdf5,
那么模型被保存的的文件名就会有训练轮数和验证损失。
PS:
ModelCheckpoint回调与使用model.fit()进行的训练结合使用,可以以一定间隔保存模型或权重(在检查点文件中),因此可以稍后加载模型或权重以从保存的状态继续训练。
示例
model.compile(loss=..., optimizer=...,
metrics=['accuracy'])
EPOCHS = 10
checkpoint_filepath = '/tmp/checkpoint'
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
filepath=checkpoint_filepath,
save_weights_only=True,
monitor='val_accuracy',
mode='max',
save_best_only=True)
# Model weights are saved at the end of every epoch, if it's the best seen
# so far.
model.fit(epochs=EPOCHS, callbacks=[model_checkpoint_callback])
# The model weights (that are considered the best) are loaded into the model.
model.load_weights(checkpoint_filepath)
or
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
filepath=checkpoint_filepath,
save_weights_only=True,
save_best_only=False,
verbose=1,
save_freq='epoch')
self.gpt2_model.compile(
optimizer=self.optimizer,
loss=self.loss_object,
metrics=[self.sparse_categorical_accuracy])
self.gpt2_model.fit(
x={
"input_ids": train_sample['input_ids'],
"attention_mask": train_sample['attention_mask']
},
y=train_sample['label_ids'],
batch_size=self.batch_size,
epochs=self.epochs,
shuffle=True,
validation_data=(
{
"input_ids": val_sample['input_ids'],
"attention_mask": val_sample['attention_mask']
},
val_sample['label_ids']
),
callbacks=[self.tensorboard_callback, model_checkpoint_callback]
)