1、绪论
Keras中的回调函数(Callback)是一种强大的机制,允许用户在模型训练、评估或预测的过程中插入自定义行为。这些回调函数提供了对训练循环各个阶段的访问,使得用户能够实时监控模型性能、调整训练参数、保存模型权重等操作。
例如,keras.callbacks.TensorBoard
回调可以帮助你使用TensorBoard来可视化训练过程中的数据,而 keras.callbacks.ModelCheckpoint
回调则允许你在训练过程中按照指定条件保存模型。
本文将深入探讨Keras回调函数的原理、用途以及如何实现自定义的回调函数。我们将通过几个实用的示例来展示如何使用这些回调函数,并指导程序员如何构建自己的回调函数以满足特定的需求。
通过实现keras.callbacks.Callback
基类或其子类,程序员可以定义在训练的不同阶段(如每个epoch的开始和结束、每个batch的开始和结束等)应该执行的操作。这些操作可能包括记录日志、调整学习率、早期停止训练等。掌握回调函数的使用,将使程序员能够更灵活地控制模型的训练过程,优化模型的性能。
接下来,我们将通过一系列实例来演示如何在实际应用中利用Keras回调函数,并详细解释如何构建自定义的回调函数。
2、自定义callback的方法
2.1 设置
import numpy as np
import keras
2.2 callback概览
在Keras中,所有的回调函数都继承自keras.callbacks.Callback
类,并覆盖了在训练、测试、预测过程中各个阶段被调用的方法集。回调函数在训练过程中非常有用,因为它们允许我们查看模型的内部状态和统计信息。
程序员可以将回调函数列表(作为关键字参数callbacks
)传递给以下模型方法:
keras.Model.fit()
:用于训练模型。keras.Model.evaluate()
:用于评估模型的性能。keras.Model.predict()
:用于对新的数据进行预测。
当调用这些方法时,Keras会在适当的阶段调用你提供的回调函数的相应方法。例如,on_epoch_begin
在每个epoch开始时被调用,on_epoch_end
在每个epoch结束时被调用,on_batch_begin
在每个batch开始时被调用,等等。
通过实现这些回调函数方法,程序员可以执行各种任务,如记录日志、可视化训练过程、调整学习率、在特定条件下保存模型等。这使得程序员能够灵活地定制模型的训练、评估和预测过程。
下面是一些常用的Keras回调函数示例:
keras.callbacks.ModelCheckpoint
:在训练过程中定期保存模型。keras.callbacks.TensorBoard
:使用TensorBoard可视化训练过程。keras.callbacks.ReduceLROnPlateau
:当学习停滞时减少学习率。keras.callbacks.EarlyStopping
:当性能不再提高时提前停止训练。
除了使用这些预定义的回调函数外,程序员还可以通过继承keras.callbacks.Callback
类并覆盖相应的方法来创建自定义的回调函数。这样,你可以根据具体需求定制模型的训练、评估和预测过程。
回调函数方法概述
在Keras中,回调函数提供了在模型训练、评估和预测过程中的不同阶段执行自定义操作的机制。这些操作通过覆盖keras.callbacks.Callback
基类中的方法来实现。以下是一些主要的回调方法及其概述:
2.2.1全局方法
on_train_begin(self, logs=None)
- 在
fit
方法开始时被调用。
on_train_end(self, logs=None)
- 在
fit
方法结束时被调用。
on_test_begin(self, logs=None)
- 在
evaluate
方法开始时被调用(用于测试集)。
on_test_end(self, logs=None)
- 在
evaluate
方法结束时被调用(用于测试集)。
on_predict_begin(self, logs=None)
- 在
predict
方法开始时被调用。
on_predict_end(self, logs=None)
- 在
predict
方法结束时被调用。
2.2.2批次级别方法(训练/测试/预测)
on_train_batch_begin(self, batch, logs=None)
- 在处理训练批次之前被调用。
on_train_batch_end(self, batch, logs=None)
- 在处理完训练批次后被调用。
logs
是一个字典,包含该批次的度量结果。
on_test_batch_begin(self, batch, logs=None)
- 在处理测试批次之前被调用(用于测试集)。
on_test_batch_end(self, batch, logs=None)
- 在处理完测试批次后被调用(用于测试集)。
on_predict_batch_begin(self, batch, logs=None)
- 在处理预测批次之前被调用。
on_predict_batch_end(self, batch, logs=None)
- 在处理完预测批次后被调用。
2.2.3周期级别方法(仅训练)
on_epoch_begin(self, epoch, logs=None)
- 在每个训练周期开始时被调用。
on_epoch_end(self, epoch, logs=None)
- 在每个训练周期结束时被调用。
logs
是一个字典,包含该周期的度量结果,如损失值、准确率等。
在构建自定义回调时,你可以根据需要覆盖这些方法。在方法内部,你可以访问模型的内部状态、修改训练参数、记录日志等。这使得回调函数成为定制模型训练过程的有力工具。
2.3 自定义callback
以下我们通过一个具体的例子来演示如何使用回调函数。首先,我们需要导入TensorFlow库并定义一个简单的Sequential Keras模型。
# Define the Keras model to add callbacks to
def get_model():
model = keras.Sequential()
model.add(keras.layers.Dense(1))
model.compile(
optimizer=keras.optimizers.RMSprop(learning_rate=0.1),
loss="mean_squared_error",
metrics=["mean_absolute_error"],
)
return model
接下来,就可以使用Keras的数据集API来加载MNIST数据集用于训练和测试。
# Load example MNIST data and pre-process it
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = x_train.reshape(-1, 784).astype("float32") / 255.0
x_test = x_test.reshape(-1, 784).astype("float32") / 255.0
# Limit the data to 1000 samples
x_train = x_train[:1000]
y_train = y_train[:1000]
x_test = x_test[:1000]
y_test = y_test[:1000]
定义一个简单的自定义回调函数,用于记录以下事件:
- 当
fit
、evaluate
、predict
开始时和结束时 - 当每个周期开始时和结束时
- 当每个训练批次开始时和结束时
- 当每个评估(测试)批次开始时和结束时
- 当每个推理(预测)批次开始时和结束时
class CustomCallback(keras.callbacks.Callback):
def on_train_begin(self, logs=None):
keys = list(logs.keys())
print("Starting training; got log keys: {}".format(keys))
def on_train_end(self, logs=None):
keys = list(logs.keys())
print("Stop training; got log keys: {}".format(keys))
def on_epoch_begin(self, epoch, logs=None):
keys = list(logs.keys())
print("Start epoch {} of training; got log keys: {}".format(epoch, keys))
def on_epoch_end(self, epoch, logs=None):
keys = list(logs.keys())
print("End epoch {} of training; got log keys: {}".format(epoch, keys))
def on_test_begin(self, logs=None):
keys = list(logs.keys())
print("Start testing; got log keys: {}".format(keys))
def on_test_end(self, logs=None):
keys = list(logs.keys())
print("Stop testing; got log keys: {}".format(keys))
def on_predict_begin(self, logs=None):
keys = list(logs.keys())
print("Start predicting; got log keys: {}".format(keys))
def on_predict_end(self, logs=None):
keys = list(logs.keys())
print("Stop predicting; got log keys: {}".format(keys))
def on_train_batch_begin(self, batch, logs=None):
keys = list(logs.keys())
print("...Training: start of batch {}; got log keys: {}".format(batch, keys))
def on_train_batch_end(self, batch, logs=None):
keys = list(logs.keys())
print("...Training: end of batch {}; got log keys: {}".format(batch, keys))
def on_test_batch_begin(self, batch, logs=None):
keys = list(logs.keys())
print("...Evaluating: start of batch {}; got log keys: {}".format(batch, keys))
def on_test_batch_end(self, batch, logs=None):
keys = list(logs.keys())
print("...Evaluating: end of batch {}; got log keys: {}".format(batch, keys))
def on_predict_batch_begin(self, batch, logs=None):
keys = list(logs.keys())
print("...Predicting: start of batch {}; got log keys: {}".format(batch, keys))
def on_predict_batch_end(self, batch, logs=None):
keys = list(logs.keys())
print("...Predicting: end of batch {}; got log keys: {}".format(batch, keys))
定义完成后,就可以在脚本中进行使用
model = get_model()
model.fit(
x_train,
y_train,
batch_size=128,
epochs=1,
verbose=0,
validation_split=0.5,
callbacks=[CustomCallback()],
)
res = model.evaluate(
x_test, y_test, batch_size=128, verbose=0, callbacks=[CustomCallback()]
)
res = model.predict(x_test, batch_size=128, callbacks=[CustomCallback()])
2.3.1 使用log字典
在Keras中,logs字典是在每个批次或周期结束时传递给回调函数的参数之一。它包含了在当前批次或周期计算出的损失值和所有指标值。这些值对于监控训练过程、调整学习率、进行早停(Early Stopping)等非常有用。
例如,如果你在模型编译时指定了损失函数和指标(如mean_absolute_error),那么这些值在每个批次或周期结束后都会计算并存储在logs字典中。
class LossAndErrorPrintingCallback(keras.callbacks.Callback):
def on_train_batch_end(self, batch, logs=None):
print(
"Up to batch {}, the average loss is {:7.2f}.".format(batch, logs["loss"])
)
def on_test_batch_end(self, batch, logs=None):
print(
"Up to batch {}, the average loss is {:7.2f}.".format(batch, logs["loss"])
)
def on_epoch_end(self, epoch, logs=None):
print(
"The average loss for epoch {} is {:7.2f} "
"and mean absolute error is {:7.2f}.".format(
epoch, logs["loss"], logs["mean_absolute_error"]
)
)
model = get_model()
model.fit(
x_train,
y_train,
batch_size=128,
epochs=2,
verbose=0,
callbacks=[LossAndErrorPrintingCallback()],
)
res = model.evaluate(
x_test,
y_test,
batch_size=128,
verbose=0,
callbacks=[LossAndErrorPrintingCallback()],
)
2.3.2 使用self.model属性
除了在其方法被调用时接收日志信息外,回调还具有访问与当前训练/评估/推理轮次关联的模型的权限:self.model
。
以下是程序员可以在回调中使用self.model
做的几件事:
- 设置
self.model.stop_training = True
以立即中断训练。 - 修改优化器的超参数(作为
self.model.optimizer
可用),比如self.model.optimizer.learning_rate
。 - 在指定的时间间隔内保存模型。
- 在每个周期结束时,记录模型在几个测试样本上的
model.predict()
输出,以在训练过程中进行合理性检查。 - 在每个周期结束时提取中间特征的可视化,以监控模型随时间的学习情况。
让我们通过几个例子来看看这些功能是如何工作的。
在最小损失时的提前停止
以下的例子展示了如何创建一个回调函数,该回调函数在达到最小损失时停止训练,通过设置属性self.model.stop_training
(布尔值)来实现。可选地,程序员可以提供一个patience
参数来指定在达到局部最小值后应该等待多少个周期再停止。
class EarlyStoppingAtMinLoss(keras.callbacks.Callback):
"""Stop training when the loss is at its min, i.e. the loss stops decreasing.
Arguments:
patience: Number of epochs to wait after min has been hit. After this
number of no improvement, training stops.
"""
def __init__(self, patience=0):
super().__init__()
self.patience = patience
# best_weights to store the weights at which the minimum loss occurs.
self.best_weights = None
def on_train_begin(self, logs=None):
# The number of epoch it has waited when loss is no longer minimum.
self.wait = 0
# The epoch the training stops at.
self.stopped_epoch = 0
# Initialize the best as infinity.
self.best = np.Inf
def on_epoch_end(self, epoch, logs=None):
current = logs.get("loss")
if np.less(current, self.best):
self.best = current
self.wait = 0
# Record the best weights if current results is better (less).
self.best_weights = self.model.get_weights()
else:
self.wait += 1
if self.wait >= self.patience:
self.stopped_epoch = epoch
self.model.stop_training = True
print("Restoring model weights from the end of the best epoch.")
self.model.set_weights(self.best_weights)
def on_train_end(self, logs=None):
if self.stopped_epoch > 0:
print(f"Epoch {self.stopped_epoch + 1}: early stopping")
model = get_model()
model.fit(
x_train,
y_train,
batch_size=64,
epochs=30,
verbose=0,
callbacks=[LossAndErrorPrintingCallback(), EarlyStoppingAtMinLoss()],
)
学习率调度
在这个例子中,我们将展示如何使用自定义回调在训练过程中动态更改优化器的学习率。
class CustomLearningRateScheduler(keras.callbacks.Callback):
"""Learning rate scheduler which sets the learning rate according to schedule.
Arguments:
schedule: a function that takes an epoch index
(integer, indexed from 0) and current learning rate
as inputs and returns a new learning rate as output (float).
"""
def __init__(self, schedule):
super().__init__()
self.schedule = schedule
def on_epoch_begin(self, epoch, logs=None):
if not hasattr(self.model.optimizer, "learning_rate"):
raise ValueError('Optimizer must have a "learning_rate" attribute.')
# Get the current learning rate from model's optimizer.
lr = self.model.optimizer.learning_rate
# Call schedule function to get the scheduled learning rate.
scheduled_lr = self.schedule(epoch, lr)
# Set the value back to the optimizer before this epoch starts
self.model.optimizer.learning_rate = scheduled_lr
print(f"\nEpoch {epoch}: Learning rate is {float(np.array(scheduled_lr))}.")
LR_SCHEDULE = [
# (epoch to start, learning rate) tuples
(3, 0.05),
(6, 0.01),
(9, 0.005),
(12, 0.001),
]
def lr_schedule(epoch, lr):
"""Helper function to retrieve the scheduled learning rate based on epoch."""
if epoch < LR_SCHEDULE[0][0] or epoch > LR_SCHEDULE[-1][0]:
return lr
for i in range(len(LR_SCHEDULE)):
if epoch == LR_SCHEDULE[i][0]:
return LR_SCHEDULE[i][1]
return lr
model = get_model()
model.fit(
x_train,
y_train,
batch_size=64,
epochs=15,
verbose=0,
callbacks=[
LossAndErrorPrintingCallback(),
CustomLearningRateScheduler(lr_schedule),
],
)
3、总结
关于Keras自定义回调(Callback)的总结,我们可以简要地归纳为以下几个方面:
3.1 定义与用途
- 回调函数是一组在训练的特定阶段被调用的函数集。
- 你可以使用回调函数来观察训练过程中网络内部的状态和统计信息。
- 通过传递回调函数列表到模型的
.fit()
中,即可在给定的训练阶段调用该函数集中的函数。
3.2 自定义Callback的基础
- Keras中的
keras.callbacks.Callback
是回调函数的抽象类,定义新的回调函数必须继承自该类。 - 类属性
params
是一个字典,包含训练参数集,如verbosity
(信息显示方法)、batch
大小、epoch
数等。 model
属性是keras.models.Model
对象,为正在训练的模型的引用。- 回调函数以字典
logs
为参数,该字典包含了一系列与当前batch
或epoch
相关的信息。
3.3 常见用途
- 早停(Early Stopping):当验证误差不再改善时,自动停止训练,防止过拟合。
- 学习率调度(Learning Rate Scheduling):在训练过程中动态改变学习率。
- 权重保存(Model Checkpointing):在每个epoch结束时保存模型权重,以便在后续可以加载最佳模型。
- TensorBoard日志记录:将训练过程中的统计信息写入TensorBoard日志文件,以便可视化。
- 自定义监控指标:在训练过程中记录或修改任何你感兴趣的指标。
3.4 实现自定义Callback
- 定义一个类,继承自
keras.callbacks.Callback
。 - 根据需要重写特定的方法,如
on_epoch_begin
、on_epoch_end
、on_batch_begin
、on_batch_end
、on_train_begin
、on_train_end
等。 - 在这些方法内部,你可以访问
self.model
和self.params
,以及传递进来的logs
字典。 - 使用
self.model
的属性或方法来操作模型,如改变学习率、保存模型等。
3.5 注意事项
- 自定义回调时要确保不要引入意外的副作用,特别是当在多个地方使用同一个模型或数据集时。
- 如果回调需要访问模型的状态(如权重),请确保在适当的时候(如
on_epoch_end
)进行访问,以避免在训练过程中造成不一致。 - 如果回调需要与其他回调或训练循环的其他部分进行交互,请确保遵循Keras的回调执行顺序和约定。