回调 call back

本文介绍了Keras中回调函数的使用,包括ModelCheckpoint用于模型保存,EarlyStopping实现早期停止,以及如何自定义回调以监控训练过程。通过设置callback参数,可以在训练过程中实现模型性能最佳时的保存,并防止过拟合。同时,展示了如何编写自定义回调以显示验证损失与训练损失的比例,以辅助检测过拟合。
摘要由CSDN通过智能技术生成

Hands-on Machine Learning with Scikit Learn, Keras & TensorFlow

Charpter 10 学习笔记

fit()方法接受一个回调参数,允许您指定Keras将在训练开始和结束时、在每个epoch开始和结束时、甚至在处理每个批处理之前和之后调用的对象列表。例如,ModelCheckpoint回调在训练期间定期保存模型的检查点,默认情况下保存在每个epoch的末尾:

[...]
checkpoint_cb = keras.callbacks.ModelCheckpoint("my_keras_model.h5")
history = model.fit(X_train, y_train, eporchs = 10, callbacks = [checkpoint_cb])

此外,如果您在培训期间使用验证集,您可以在创建ModelCheckpoint时设置save_best_only = True。在本例中,只有当模型在验证集上的性能达到最佳时,它才会保存模型。这样,您就不需要担心训练时间过长和训练集过拟合的问题:只需恢复训练后保存的最后一个模型,这将是验证集上的最佳模式。下面的代码是一种实现早期停止的简单方法:

checkpoint_cb = keras.callbacks.ModelCheckpoint("my_keras_model.h5")
history = model.fit(X_train, y_train, eporchs = 10, validation_data = (X_valid, y_valid), callbacks = [checkpoint_cb])
model = keras.models.load_model("my_keras_model.h5")
#roll back to best model

实现早期停止的另一种方法是简单地使用earlystop回调。当它在许多时期(由patience参数定义)没有测量到验证集上的进展时,它将中断训练,并可选地回滚到最佳模型。你可以结合两个回调来保存你的模型检查点(以防你的计算机崩溃)和中断训练早期,当没有更多的进展(以避免浪费时间和资源):

early_stopping_cb = keras.callbacks.EarlyStopping(patience = 10, restore_best_weights = True)
history = model.fit(X_train, y_train, eporchs = 100, validation_data = (X_valid, y_valid), callbacks = [checkpoint_cb, early_stopping_cb])

eporch的数量可以设置为一个较大的值,因为当没有更多的进展时,训练将自动停止。在这种情况下,不需要恢复保存的最佳模型,因为earlystop回调将跟踪最佳权重,并在训练结束时为您恢复它们。

如果您需要额外的控制,您可以轻松地编写自己的自定义回调。下面的自定义回调函数将显示在训练期间验证损失和训练损失之间的比率(例如,检测过拟合):

class PrintValTrainRatioCallback(keras.callbacks.Callback):
    def on_eporch_end(self, eporch, logs):
        print("\nval/train: {:.2f}".format(logs["val_loss"] / logs["loss"]))

正如你所期望的,你可以实现on_train_begin(), on_train_end(), on_epoch_begin(), on_epoch_end(), on_batch_begin()和on_batch_end()。如果需要的话,还可以在评估和预测期间使用回调(例如,用于调试)。对于求值,你应该实现on_test_begin(), on_test_end(), on_test_batch_begin(), 或on_test_batch_end()(由evaluate()调用),对于预测,你应该实现on_predict_begin(), on_predict_end(), on_predict_batch_begin(), 或on_predict_batch_end()(由predict()调用)。

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值