深度学习 - 5.TF x Keras 编写回调函数

36 篇文章 6 订阅
24 篇文章 11 订阅

一.回调函数概述

1.回调函数的功能

回调是一种强大的工具,可以在训练,评估或推理期间自定义Keras模型的行为。我们可以在不同时期实现不同功能的回调函数,例如每次训练的开始或结束,每轮epoch的开始或结束,每一批batch的开始或结束等,并实现以下行为:

->在训练过程中的不同时间点进行验证(除了内置的按时间段验证)
->定期或在超过特定精度阈值时对模型进行检查
->当训练似乎停滞不前时,更改模型的学习率
->当训练似乎停滞不前时,对顶层进行微调
->在训练结束或超出特定性能阈值时发送电子邮件或即时消息通知

2.回调函数的使用

可以将回调列表(作为关键字参数callbacks )传递给以下模型方法:
keras.Model.fit()
keras.Model.evaluate()
keras.Model.predict()

二.回调函数实现方式

回调函数支持调用官方回调函数如早停(early stop),定期存储ckpt(checkPoint),动态学习率(LearningRateScheduler)等,也可以自己继承keras.callbacks.Callback,并在相对应的时期实现自己的功能。

1.官方支持的回调函数

下面列出了官方 API 中给出的回调函数,后续也会给出一些常用回调函数的使用方法:早停,定期存储,学习率调整等。注意这里并不限定每次只使用一个回调函数,也可以在 model fit / evaluate / predict 过程中传入回调函数列表,这样可以同时使用多个回调函数。

class BaseLogger:累积指标的历时平均值的回调。

class CSVLogger:将纪元结果流式传输到CSV文件的回调。

class Callback:用于建立新回调的抽象基类。

class CallbackList:容器抽象一个回调列表。

class EarlyStopping:在受监控的指标停止改进时停止训练。

class History:将事件记录到History对象中的回调。

class LambdaCallback:用于即时创建简单的自定义回调的回调。

class LearningRateScheduler:学习率调度程序。

class ModelCheckpoint:回调以某种频率保存Keras模型或模型权重。

class ProgbarLogger:将指标输出到stdout的回调。

class ReduceLROnPlateau:当指标停止改善时,降低学习率。

class RemoteMonitor:用于将事件流传输到服务器的回调。

class TensorBoard:为TensorBoard启用可视化。

class TerminateOnNaN:在遇到NaN丢失时终止训练的回调。

2.自定义回调函数方法

也可以像之前自定义损失,自定义指标一样,自定义回调函数,这里需要继承 Callback。下面展示的是所有方法的demo用法,可以在训练开始、结束时,每个epoch开始、结束时等等等等实现自己的回调。实际使用当中,并不需要实现全部方法,只需要根据自己的需求实现对应方法,例如我只想在模型开始训练前,输出 "hello world",则我只需要写 on_train_begin 一个方法就行。其次,logs内的内容也是根据自己模型返回的字典来决定的,例如编译模型时我传入了 metrics=["mean_absolute_error"],则在实现 CallBack 时我就可以从 logs 里得到 mean_absolute_error 的值,随后可以根据自己的需求进行打印输出等操作。

class CustomCallback(keras.callbacks.Callback):
    # 定义时需要注意 这里一些logs可能是空 所以需要注意拿的元素是否为None
    # example:
    # Training: end of batch 0; got log keys: ['batch', 'size', 'loss', 'mean_absolute_error']
    def on_train_begin(self, logs=None):
        keys = list(logs.keys())
        print("Starting training; got log keys: {}".format(keys))

    def on_train_end(self, logs=None):
        keys = list(logs.keys())
        print("Stop training; got log keys: {}".format(keys))

    def on_epoch_begin(self, epoch, logs=None):
        keys = list(logs.keys())
        print("Start epoch {} of training; got log keys: {}".format(epoch, keys))

    def on_epoch_end(self, epoch, logs=None):
        keys = list(logs.keys())
        print("End epoch {} of training; got log keys: {}".format(epoch, keys))

    def on_test_begin(self, logs=None):
        keys = list(logs.keys())
        print("Start testing; got log keys: {}".format(keys))

    def on_test_end(self, logs=None):
        keys = list(logs.keys())
        print("Stop testing; got log keys: {}".format(keys))

    def on_predict_begin(self, logs=None):
        keys = list(logs.keys())
        print("Start predicting; got log keys: {}".format(keys))

    def on_predict_end(self, logs=None):
        keys = list(logs.keys())
        print("Stop predicting; got log keys: {}".format(keys))

    def on_train_batch_begin(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Training: start of batch {}; got log keys: {}".format(batch, keys))

    def on_train_batch_end(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Training: end of batch {}; got log keys: {}".format(batch, keys))

    def on_test_batch_begin(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Evaluating: start of batch {}; got log keys: {}".format(batch, keys))

    def on_test_batch_end(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Evaluating: end of batch {}; got log keys: {}".format(batch, keys))

    def on_predict_batch_begin(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Predicting: start of batch {}; got log keys: {}".format(batch, keys))

    def on_predict_batch_end(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Predicting: end of batch {}; got log keys: {}".format(batch, keys))

三.常用回调函数实现

先铺一些所有样例都用的代码,主要是编译的模型和手写识别训练数据。

def get_model():
    model = keras.Sequential()
    model.add(keras.layers.Dense(1, input_dim=784))
    model.compile(
        optimizer=keras.optimizers.RMSprop(learning_rate=0.1),
        loss="mean_squared_error",
        metrics=["mean_absolute_error"],
    )
    return model

# Load example MNIST data and pre-process it
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = x_train.reshape(-1, 784).astype("float32") / 255.0
x_test = x_test.reshape(-1, 784).astype("float32") / 255.0

# Limit the data to 1000 samples
x_train = x_train[:1000]
y_train = y_train[:1000]
x_test = x_test[:1000]
y_test = y_test[:1000]

1.定点输出自定义日志

这里在每个batch后输出了loss值,并在每一个epoch后输出了loss和mae,也可以自己编译时加入更多参数,然后在实现回调时从logs的字典里获取相关内容。

class LossAndErrorPrintingCallback(keras.callbacks.Callback):
    def on_train_batch_end(self, batch, logs=None):
        print("For batch {}, loss is {:7.2f}.".format(batch, logs["loss"]))

    def on_test_batch_end(self, batch, logs=None):
        print("For batch {}, loss is {:7.2f}.".format(batch, logs["loss"]))

    def on_epoch_end(self, epoch, logs=None):
        print(
            "The average loss for epoch {} is {:7.2f} "
            "and mean absolute error is {:7.2f}.".format(
                epoch, logs["loss"], logs["mean_absolute_error"]
            )
        )

model = get_model()
model.fit(
    x_train,
    y_train,
    batch_size=128,
    epochs=2,
    verbose=0,
    callbacks=[LossAndErrorPrintingCallback()],
)

回调结果: 

For batch 0, loss is   28.13.
For batch 1, loss is  895.41.
For batch 2, loss is   29.60.
For batch 3, loss is   11.48.
For batch 4, loss is    7.85.
For batch 5, loss is    6.86.
For batch 6, loss is    6.40.
For batch 7, loss is    5.11.
The average loss for epoch 0 is  126.70 and mean absolute error is    6.15.
For batch 0, loss is    4.54.
For batch 1, loss is    5.71.
For batch 2, loss is    4.40.
For batch 3, loss is    4.75.
For batch 4, loss is    4.39.
For batch 5, loss is    4.61.
For batch 6, loss is    4.66.
For batch 7, loss is    4.89.
The average loss for epoch 1 is    4.74 and mean absolute error is    1.75.

2.Early Stop

模型训练过拟合的风险会随着轮数的增加而增加,所以可以在模型达到一定精度要求,或者模型几次训练相对应指标都无明显变化时,停止模型的训练。keras中,设置self.model.stop_training = True可以立即中断训练。

class EarlyStoppingAtMinLoss(keras.callbacks.Callback):
    """Stop training when the loss is at its min, i.e. the loss stops decreasing.

  Arguments:
      patience: Number of epochs to wait after min has been hit. After this
      number of no improvement, training stops.
  """

    def __init__(self, patience=0):
        super(EarlyStoppingAtMinLoss, self).__init__()
        # 耐心度与最佳权重保存
        self.patience = patience
        # best_weights to store the weights at which the minimum loss occurs.
        self.best_weights = None

    # 训练开始时初始化最佳间距与容忍度
    def on_train_begin(self, logs=None):
        # The number of epoch it has waited when loss is no longer minimum.
        self.wait = 0
        # The epoch the training stops at.
        self.stopped_epoch = 0
        # Initialize the best as infinity.
        self.best = np.Inf

    def on_epoch_end(self, epoch, logs=None):
        # 获取当前loss值 如果小于best 则更新best为当前currebt 并存储当前最佳模型参数
        current = logs.get("loss")
        if np.less(current, self.best):
            self.best = current
            self.wait = 0
            # Record the best weights if current results is better (less).
            self.best_weights = self.model.get_weights()
        else:
            # 如果一次训练 loss未减少 则减少一次容忍度
            self.wait += 1
            if self.wait >= self.patience:
                # 如果超过容忍度 则记录当前epoch轮数 停止模型训练 获取最小loss时的weights
                self.stopped_epoch = epoch
                self.model.stop_training = True
                print("Restoring model weights from the end of the best epoch.")
                self.model.set_weights(self.best_weights)

    def on_train_end(self, logs=None):
        if self.stopped_epoch > 0:
            print("Epoch %05d: early stopping" % (self.stopped_epoch + 1))

model = get_model()
model.fit(
    x_train,
    y_train,
    batch_size=64,
    steps_per_epoch=5,
    epochs=30,
    verbose=0,
    # 可以自定义 patience
    callbacks=[LossAndErrorPrintingCallback(), EarlyStoppingAtMinLoss(2)],
)

EarlyStoppingAtMinLoss 类会在每一个epoch后检查模型loss值的变化并更新记录全局最优loss的值和参数,这其中添加了容忍度的参数,如果超过了容忍度的限制,则停止模型训练并获取最优loss时模型的训练参数当做本次模型训练的结果。这里我们还通过list传入了上面定点输出日志的callback,看一下两个callback同时执行的效果。因为我们的容忍度为2,第一次loss从2.06增加到2.27,容忍度 += 1,第二次loss从2.27增加到4.59,容忍度 += 1达到我们的预定要求,所以训练提前结束。

For batch 0, loss is   29.23.
For batch 1, loss is  872.63.
For batch 2, loss is   21.21.
For batch 3, loss is    9.73.
For batch 4, loss is    7.15.
The average loss for epoch 0 is  187.99 and mean absolute error is    8.19.
For batch 0, loss is    7.31.
For batch 1, loss is    6.15.
For batch 2, loss is    6.67.
For batch 3, loss is    4.55.
For batch 4, loss is    6.30.
The average loss for epoch 1 is    6.20 and mean absolute error is    2.06.
For batch 0, loss is    4.98.
For batch 1, loss is    5.11.
For batch 2, loss is    8.20.
For batch 3, loss is    8.60.
For batch 4, loss is   11.30.
The average loss for epoch 2 is    7.64 and mean absolute error is    2.27.
For batch 0, loss is   15.43.
For batch 1, loss is   17.26.
For batch 2, loss is   16.51.
For batch 3, loss is   27.80.
For batch 4, loss is   64.24.
The average loss for epoch 3 is   29.29 and mean absolute error is    4.59.
Restoring model weights from the end of the best epoch.
Epoch 00004: early stopping

3.学习率调整

不同的训练时期,模型参数更新的步长并不完全一致,这里通过简单的分段函数,对模型的学习率进行调整,已便在后期可以缩小步长,更快的迭代到最优参数。

def learningRate():
    # 在训练过程中使用自定义的回调来动态更改优化器的学习率。
    class CustomLearningRateScheduler(keras.callbacks.Callback):
        """Learning rate scheduler which sets the learning rate according to schedule.

      Arguments:
          schedule: a function that takes an epoch index
              (integer, indexed from 0) and current learning rate
              as inputs and returns a new learning rate as output (float).
      """

        def __init__(self, schedule):
            super(CustomLearningRateScheduler, self).__init__()
            self.schedule = schedule

        def on_epoch_begin(self, epoch, logs=None):
            # hasattr object内是否包含变量x
            if not hasattr(self.model.optimizer, "lr"):
                raise ValueError('Optimizer must have a "lr" attribute.')
            # Get the current learning rate from model's optimizer.
            # 模型必须有学习率的参数 通过getValue(获取学习率 model.optimizer.learning_rate)
            lr = float(tf.keras.backend.get_value(self.model.optimizer.learning_rate))
            # Call schedule function to get the scheduled learning rate.
            # 通过分段函数动态修改模型的学习率
            scheduled_lr = self.schedule(epoch, lr)
            # Set the value back to the optimizer before this epoch starts
            # backend 即为反向传播 为反向传播设置学习率 为函数修正后的学习率
            tf.keras.backend.set_value(self.model.optimizer.lr, scheduled_lr)
            print("\nEpoch %05d: Learning rate is %6.4f." % (epoch, scheduled_lr))

    LR_SCHEDULE = [
        # (epoch to start, learning rate) tuples
        (3, 0.05),
        (6, 0.01),
        (9, 0.005),
        (12, 0.001),
    ]

    def lr_schedule(epoch, lr):
        """Helper function to retrieve the scheduled learning rate based on epoch."""
        if epoch < LR_SCHEDULE[0][0] or epoch > LR_SCHEDULE[-1][0]:
            return lr
        for i in range(len(LR_SCHEDULE)):
            if epoch == LR_SCHEDULE[i][0]:
                return LR_SCHEDULE[i][1]
        return lr

model = get_model()
model.fit(
    x_train,
    y_train,
    batch_size=64,
    steps_per_epoch=5,
    epochs=15,
    verbose=0,
    callbacks=[
        LossAndErrorPrintingCallback(),
        CustomLearningRateScheduler(lr_schedule),
    ],
)

loss的输出日志这里就不赘述了,看一下学习率参数的调整:

Epoch 00000: Learning rate is 0.1000.
The average loss for epoch 0 is  182.18 and mean absolute error is    8.15.

Epoch 00001: Learning rate is 0.1000.
The average loss for epoch 1 is    6.83 and mean absolute error is    2.11.

Epoch 00002: Learning rate is 0.1000.
The average loss for epoch 2 is    9.19 and mean absolute error is    2.45.

Epoch 00003: Learning rate is 0.0500.
The average loss for epoch 3 is    4.64 and mean absolute error is    1.69.

Epoch 00004: Learning rate is 0.0500.
The average loss for epoch 4 is    3.65 and mean absolute error is    1.55.

Epoch 00005: Learning rate is 0.0500.
The average loss for epoch 5 is    3.98 and mean absolute error is    1.57.

Epoch 00006: Learning rate is 0.0100.
The average loss for epoch 6 is    3.36 and mean absolute error is    1.44.

Epoch 00007: Learning rate is 0.0100.
The average loss for epoch 7 is    4.17 and mean absolute error is    1.62.

Epoch 00008: Learning rate is 0.0100.
The average loss for epoch 8 is    3.27 and mean absolute error is    1.43.

Epoch 00009: Learning rate is 0.0050.
The average loss for epoch 9 is    2.83 and mean absolute error is    1.30.

Epoch 00010: Learning rate is 0.0050.
The average loss for epoch 10 is    2.96 and mean absolute error is    1.36.

Epoch 00011: Learning rate is 0.0050.
The average loss for epoch 11 is    3.53 and mean absolute error is    1.44.

Epoch 00012: Learning rate is 0.0010.
The average loss for epoch 12 is    3.19 and mean absolute error is    1.39.

Epoch 00013: Learning rate is 0.0010.
The average loss for epoch 13 is    3.22 and mean absolute error is    1.37.

Epoch 00014: Learning rate is 0.0010.
The average loss for epoch 14 is    2.98 and mean absolute error is    1.38.

4.定期存储 ckpt

定期存储checkpoint是很好地习惯,既可以避免模型训练失败重启的高代价,也可以保存多版模型参数,用作后期调整与备选。

model = get_model()

callbacks = [
    keras.callbacks.ModelCheckpoint(
        # Path where to save the model
        # The two parameters below mean that we will overwrite
        # the current checkpoint if and only if
        # the `val_loss` score has improved.
        # The saved model name will include the current epoch.
        filepath="./ckpt/mymodel_{epoch}",
        save_best_only=True,  # Only save a model if `val_loss` has improved.
        monitor="val_loss",
        verbose=1,
    )
]
model.fit(
    x_train, y_train, epochs=2, batch_size=64, callbacks=callbacks, validation_split=0.2
)

这里设定了每一轮epoch存储一次模型,且设置了只有当val_loss提升时才出发模型存储。执行后就可以看到对应路径下已经存好前两轮的ckpt,如果遇到模型意外退出等情况,则可以记载ckpt继续之前的训练。

Tips:

前三个都是通过自定义的方式实现了回调函数,最后一个是通过调用官方的Api实现,全部的API上面也列出了,获取更多细节可以查看回调函数官方Api

更多推荐算法相关深度学习:深度学习导读专栏 

评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

BIT_666

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值