
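Callbacks are a powerful tool in TensorFlow 2.0 for customizing model behavior, for example adjusting the learning rate during training or visualizing the model. They are called automatically at different stages of training, testing, and inference, and they can access the model's state and statistics. This article introduces the basic concepts of callbacks, including methods such as `on_train_begin` and `on_batch_end`, and gives an example callback that adjusts the learning rate.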


!pip3 install tensorflow==2.0.0a0
%matplotlib inline
import tensorflow as tf
from tensorflow import keras
import numpy as np
from matplotlib import pyplot as plt
def get_model():
    model = keras.Sequential()
    model.add(keras.layers.Dense(1, activation='linear', input_dim=784))
    model.compile(optimizer=tf.keras.optimizers.RMSprop(lr=0.1), loss='mean_squared_error', metrics=['mae'])
    return model


(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = x_train.reshape(60000, 784).astype('float32') / 255
x_test = x_test.reshape(10000, 784).astype('float32') / 255


import datetime
class MyCustomCallback(keras.callbacks.Callback):
    def on_train_batch_begin(self, batch, logs=None):
        print('Training: batch {} begins at {}'.format(batch, datetime.datetime.now().time()))
    def on_train_batch_end(self, batch, logs=None):
        print('Training: batch {} ends at {}'.format(batch, datetime.datetime.now().time()))
    def on_test_batch_begin(self, batch, logs=None):
        print('Evaluating: batch {} begins at {}'.format(batch, datetime.datetime.now().time()))
    def on_test_batch_end(self, batch, logs=None):
        print('Evaluating: batch {} ends at {}'.format(batch, datetime.datetime.now().time()))


model = get_model()
_ = model.fit(x_train, y_train, batch_size=64, epochs=1, steps_per_epoch=5, verbose=0, callbacks=[MyCustomCallback()])
Training: batch 0 begins at 13:26:24.561418
Training: batch 0 ends at 13:26:24.760678
Training: batch 1 begins at 13:26:24.761078
Training: batch 1 ends at 13:26:24.804738
Training: batch 2 begins at 13:26:24.805276
Training: batch 2 ends at 13:26:24.840322
Training: batch 3 begins at 13:26:24.840908
Training: batch 3 ends at 13:26:24.872811
Training: batch 4 begins at 13:26:24.873087
Training: batch 4 ends at 13:26:24.905033



The following model methods accept a list of callbacks through the callbacks argument:
  • fit()
  • fit_generator()
  • evaluate()
  • evaluate_generator()
  • predict()
  • predict_generator()
_ = model.evaluate(x_test, y_test, batch_size=128, verbose=0, steps=5, callbacks=[MyCustomCallback()])
Evaluating: batch 0 begins at 13:26:26.929229
Evaluating: batch 0 ends at 13:26:26.980177
Evaluating: batch 1 begins at 13:26:26.980460
Evaluating: batch 1 ends at 13:26:26.983743
Evaluating: batch 2 begins at 13:26:26.984047
Evaluating: batch 2 ends at 13:26:26.987571
Evaluating: batch 3 begins at 13:26:26.987908
Evaluating: batch 3 ends at 13:26:26.991281
Evaluating: batch 4 begins at 13:26:26.991655
Evaluating: batch 4 ends at 13:26:26.995114




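  • on_(train|test|predict)_begin(self, logs=None): called when training, testing, or inference begins.
  • on_(train|test|predict)_end(self, logs=None): called when training, testing, or inference ends.
  • on_(train|test|predict)_batch_begin(self, batch, logs=None): called at the start of every batch during training, testing, or inference. Here, logs is a dictionary with two entries, batch and size, which hold the index of the current batch and its size.
  • on_(train|test|predict)_batch_end(self, batch, logs=None): called at the end of every batch during training, testing, or inference.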



  • on_epoch_begin(self, epoch, logs=None): called at the start of an epoch during training.
  • on_epoch_end(self, epoch, logs=None): called at the end of an epoch during training.
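The order in which the training hooks listed above fire can be sketched without TensorFlow. The snippet below is only an illustration: `RecordingCallback` and `toy_fit` are hypothetical names, and `toy_fit` is a simplified stand-in for a training loop, not the actual Keras implementation.

```python
class RecordingCallback:
    """Records the name of every hook that gets called (hypothetical helper)."""

    def __init__(self):
        self.events = []

    def on_train_begin(self, logs=None):
        self.events.append('on_train_begin')

    def on_train_end(self, logs=None):
        self.events.append('on_train_end')

    def on_epoch_begin(self, epoch, logs=None):
        self.events.append('on_epoch_begin({})'.format(epoch))

    def on_epoch_end(self, epoch, logs=None):
        self.events.append('on_epoch_end({})'.format(epoch))

    def on_train_batch_begin(self, batch, logs=None):
        self.events.append('on_train_batch_begin({})'.format(batch))

    def on_train_batch_end(self, batch, logs=None):
        self.events.append('on_train_batch_end({})'.format(batch))


def toy_fit(callback, epochs, steps_per_epoch):
    """Toy stand-in for a training loop; it only fires the hooks, no training."""
    callback.on_train_begin()
    for epoch in range(epochs):
        callback.on_epoch_begin(epoch)
        for batch in range(steps_per_epoch):
            callback.on_train_batch_begin(batch)
            # A real loop would run one optimization step here.
            callback.on_train_batch_end(batch)
        callback.on_epoch_end(epoch)
    callback.on_train_end()


recorder = RecordingCallback()
toy_fit(recorder, epochs=1, steps_per_epoch=2)
print('\n'.join(recorder.events))
```

The batch hooks are nested inside the epoch hooks, which in turn are nested inside the train hooks. This is why a per-epoch summary printed from on_epoch_end appears after all of that epoch's batch messages.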



class LossAndErrorPrintingCallback(tf.keras.callbacks.Callback):
    def on_train_batch_end(self, batch, logs=None):
        print('For batch {}, loss is {:7.2f}.'.format(batch, logs['loss']))

    def on_test_batch_end(self, batch, logs=None):
        print('For batch {}, loss is {:7.2f}.'.format(batch, logs['loss']))

    def on_epoch_end(self, epoch, logs=None):
        print('The average loss for epoch {} is {:7.2f} and mean absolute error is {:7.2f}.'.format(epoch, logs['loss'], logs['mae']))

model = get_model()
_ = model.fit(x_train, y_train, batch_size=64, steps_per_epoch=5, epochs=3, verbose=0, callbacks=[LossAndErrorPrintingCallback()])
For batch 0, loss is   32.18.
For batch 1, loss is  897.99.
For batch 2, loss is   28.40.
For batch 3, loss is    9.22.
For batch 4, loss is    7.45.
The average loss for epoch 0 is  195.05 and mean absolute error is    8.40.
For batch 0, loss is    6.59.
For batch 1, loss is    6.02.
For batch 2, loss is    5.61.
For batch 3, loss is    5.33.
For batch 4, loss is    5.12.
The average loss for epoch 1 is    5.73 and mean absolute error is    2.00.
For batch 0, loss is    4.96.
For batch 1, loss is    4.84.
For batch 2, loss is    4.74.
For batch 3, loss is    4.65.
For batch 4, loss is    4.57.
The average loss for epoch 2 is    4.75 and mean absolute error is    1.78.


_ = model.evaluate(x_test, y_test, batch_size=128, verbose=0, steps=20, callbacks=[LossAndErrorPrintingCallback()])
For batch 0, loss is    4.26.
For batch 1, loss is    4.26.
For batch 2, loss is    4.26.
For batch 3, loss is    4.26.
For batch 4, loss is    4.26.
For batch 5, loss is    4.26.
For batch 6, loss is    4.26.
For batch 7, loss is    4.26.
For batch 8, loss is    4.26.
For batch 9, loss is    4.26.
For batch 10, loss is    4.26.
For batch 11, loss is    4.26.
For batch 12, loss is    4.26.
For batch 13, loss is    4.26.
For batch 14, loss is    4.26.
For batch 15, loss is    4.26.
For batch 16, loss is    4.26.
For batch 17, loss is    4.26.
For batch 18, loss is    4.26.
For batch 19, loss is    4.26.



Early stopping


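class EarlyStoppingAtMinLoss(keras.callbacks.Callback):
    """Stop training when the loss stops improving.

    patience: number of epochs without improvement to tolerate before stopping.
    """
    def __init__(self, patience=0):
        super(EarlyStoppingAtMinLoss, self).__init__()
        self.patience = patience
        # Weights from the epoch with the lowest loss so far.
        self.best_weights = None
    def on_train_begin(self, logs=None):
        # Number of epochs waited since the loss last improved.
        self.wait = 0
        # Epoch at which training was stopped.
        self.stopped_epoch = 0
        # Start from infinity so the first epoch always counts as an improvement.
        self.best = np.inf
    def on_epoch_end(self, epoch, logs=None):
        current = logs.get('loss')
        if np.less(current, self.best):
            self.best = current
            self.wait = 0
            # Record the weights of the best epoch so far.
            self.best_weights = self.model.get_weights()
        else:
            self.wait += 1
            if self.wait > self.patience:
                self.stopped_epoch = epoch
                self.model.stop_training = True
                print('Restoring model weights from the end of the best epoch.')
                self.model.set_weights(self.best_weights)
    def on_train_end(self, logs=None):
        if self.stopped_epoch > 0:
            print('Epoch %05d: early stopping' % (self.stopped_epoch + 1))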
model = get_model()
_ = model.fit(x_train, y_train, batch_size=64, steps_per_epoch=5, epochs=30, verbose=0, callbacks=[LossAndErrorPrintingCallback(), EarlyStoppingAtMinLoss()])
For batch 0, loss is   24.84.
For batch 1, loss is  957.64.
For batch 2, loss is   22.89.
For batch 3, loss is    9.42.
For batch 4, loss is    7.78.
The average loss for epoch 0 is  204.52 and mean absolute error is    8.35.
For batch 0, loss is    6.81.
For batch 1, loss is    6.15.
For batch 2, loss is    5.68.
For batch 3, loss is    5.34.
For batch 4, loss is    5.11.
The average loss for epoch 1 is    5.82 and mean absolute error is    2.01.
For batch 0, loss is    4.93.
For batch 1, loss is    4.79.
For batch 2, loss is    4.69.
For batch 3, loss is    4.60.
For batch 4, loss is    4.52.
The average loss for epoch 2 is    4.71 and mean absolute error is    1.77.
For batch 0, loss is    4.45.
For batch 1, loss is    4.39.
For batch 2, loss is    4.33.
For batch 3, loss is    4.28.
For batch 4, loss is    4.23.
The average loss for epoch 3 is    4.33 and mean absolute error is    1.67.
For batch 0, loss is    4.18.
For batch 1, loss is    4.14.
For batch 2, loss is    4.10.
For batch 3, loss is    4.07.
For batch 4, loss is    4.08.
The average loss for epoch 4 is    4.12 and mean absolute error is    1.61.
For batch 0, loss is    4.38.
For batch 1, loss is    7.17.
For batch 2, loss is   32.97.
For batch 3, loss is  188.94.
For batch 4, loss is  224.35.
The average loss for epoch 5 is   91.56 and mean absolute error is    7.18.
Restoring model weights from the end of the best epoch.
Epoch 00006: early stopping
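The stopping decision in the log above can be replayed in plain Python. This is a minimal sketch: `stopping_epoch` is a hypothetical helper, and it assumes the wait counter grows only in epochs where the loss fails to improve and that training stops once the counter exceeds patience.

```python
def stopping_epoch(epoch_losses, patience=0):
    """Return the index of the epoch at which training would stop, or None."""
    best = float('inf')
    wait = 0
    for epoch, loss in enumerate(epoch_losses):
        if loss < best:
            # Improvement: remember the new best loss and reset the counter.
            best = loss
            wait = 0
        else:
            # No improvement: count the epoch and stop once patience runs out.
            wait += 1
            if wait > patience:
                return epoch
    return None


# Per-epoch average losses copied from the log above.
losses = [204.52, 5.82, 4.71, 4.33, 4.12, 91.56]
print(stopping_epoch(losses, patience=0))  # 5, reported in the log as "Epoch 00006"
print(stopping_epoch(losses, patience=1))  # None: one bad epoch is tolerated
```

With patience=0 the first epoch whose loss is worse than the best one ends training. A larger patience tolerates short spikes like the one in epoch 5.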



class LearningRateScheduler(tf.keras.callbacks.Callback):
    def __init__(self, schedule):
        super(LearningRateScheduler, self).__init__()
        self.schedule = schedule
    def on_epoch_begin(self, epoch, logs=None):
        if not hasattr(self.model.optimizer, 'lr'):
            raise ValueError('Optimizer must have a "lr" attribute.')
        # Get the current learning rate from model's optimizer.
        lr = float(tf.keras.backend.get_value(self.model.optimizer.lr))
        # Call schedule function to get the scheduled learning rate.
        scheduled_lr = self.schedule(epoch, lr)
        # Set the value back to the optimizer before this epoch starts
        tf.keras.backend.set_value(self.model.optimizer.lr, scheduled_lr)
        print('\nEpoch %05d: Learning rate is %6.4f.' % (epoch, scheduled_lr))

LR_SCHEDULE = [
    # (epoch to start, learning rate) tuples
    (3, 0.05), (6, 0.01), (9, 0.005), (12, 0.001)
]

def lr_schedule(epoch, lr):
    """Helper function to retrieve the scheduled learning rate based on epoch."""
    if epoch < LR_SCHEDULE[0][0] or epoch > LR_SCHEDULE[-1][0]:
        return lr
    for i in range(len(LR_SCHEDULE)):
        if epoch == LR_SCHEDULE[i][0]:
            return LR_SCHEDULE[i][1]
    return lr

model = get_model()
_ = model.fit(x_train, y_train, batch_size=64, steps_per_epoch=5, epochs=15, verbose=0, callbacks=[LossAndErrorPrintingCallback(), LearningRateScheduler(lr_schedule)])
Epoch 00000: Learning rate is 0.1000.
For batch 0, loss is   26.46.
For batch 1, loss is  934.59.
For batch 2, loss is   23.98.
For batch 3, loss is    8.68.
For batch 4, loss is    7.13.
The average loss for epoch 0 is  200.17 and mean absolute error is    8.30.

Epoch 00001: Learning rate is 0.1000.
For batch 0, loss is    6.31.
For batch 1, loss is    5.76.
For batch 2, loss is    5.37.
For batch 3, loss is    5.10.
For batch 4, loss is    4.91.
The average loss for epoch 1 is    5.49 and mean absolute error is    1.95.

Epoch 00002: Learning rate is 0.1000.
For batch 0, loss is    4.77.
For batch 1, loss is    4.66.
For batch 2, loss is    4.57.
For batch 3, loss is    4.50.
For batch 4, loss is    4.44.
The average loss for epoch 2 is    4.59 and mean absolute error is    1.74.

Epoch 00003: Learning rate is 0.0500.
For batch 0, loss is    4.38.
For batch 1, loss is    4.35.
For batch 2, loss is    4.32.
For batch 3, loss is    4.30.
For batch 4, loss is    4.27.
The average loss for epoch 3 is    4.32 and mean absolute error is    1.67.

Epoch 00004: Learning rate is 0.0500.
For batch 0, loss is    4.25.
For batch 1, loss is    4.22.
For batch 2, loss is    4.20.
For batch 3, loss is    4.17.
For batch 4, loss is    4.15.
The average loss for epoch 4 is    4.20 and mean absolute error is    1.64.

Epoch 00005: Learning rate is 0.0500.
For batch 0, loss is    4.12.
For batch 1, loss is    4.10.
For batch 2, loss is    4.08.
For batch 3, loss is    4.06.
For batch 4, loss is    4.04.
The average loss for epoch 5 is    4.08 and mean absolute error is    1.60.

Epoch 00006: Learning rate is 0.0100.
For batch 0, loss is    4.02.
For batch 1, loss is    4.01.
For batch 2, loss is    4.01.
For batch 3, loss is    4.00.
For batch 4, loss is    4.00.
The average loss for epoch 6 is    4.01 and mean absolute error is    1.58.

Epoch 00007: Learning rate is 0.0100.
For batch 0, loss is    3.99.
For batch 1, loss is    3.99.
For batch 2, loss is    3.98.
For batch 3, loss is    3.98.
For batch 4, loss is    3.97.
The average loss for epoch 7 is    3.98 and mean absolute error is    1.57.

Epoch 00008: Learning rate is 0.0100.
For batch 0, loss is    3.97.
For batch 1, loss is    3.96.
For batch 2, loss is    3.96.
For batch 3, loss is    3.95.
For batch 4, loss is    3.94.
The average loss for epoch 8 is    3.96 and mean absolute error is    1.57.

Epoch 00009: Learning rate is 0.0050.
For batch 0, loss is    3.94.
For batch 1, loss is    3.94.
For batch 2, loss is    3.93.
For batch 3, loss is    3.93.
For batch 4, loss is    3.93.
The average loss for epoch 9 is    3.93 and mean absolute error is    1.56.

Epoch 00010: Learning rate is 0.0050.
For batch 0, loss is    3.92.
For batch 1, loss is    3.92.
For batch 2, loss is    3.91.
For batch 3, loss is    3.91.
For batch 4, loss is    3.91.
The average loss for epoch 10 is    3.91 and mean absolute error is    1.55.

Epoch 00011: Learning rate is 0.0050.
For batch 0, loss is    3.90.
For batch 1, loss is    3.90.
For batch 2, loss is    3.89.
For batch 3, loss is    3.89.
For batch 4, loss is    3.88.
The average loss for epoch 11 is    3.89 and mean absolute error is    1.55.

Epoch 00012: Learning rate is 0.0010.
For batch 0, loss is    3.88.
For batch 1, loss is    3.88.
For batch 2, loss is    3.88.
For batch 3, loss is    3.88.
For batch 4, loss is    3.87.
The average loss for epoch 12 is    3.88 and mean absolute error is    1.54.

Epoch 00013: Learning rate is 0.0010.
For batch 0, loss is    3.87.
For batch 1, loss is    3.87.
For batch 2, loss is    3.87.
For batch 3, loss is    3.87.
For batch 4, loss is    3.87.
The average loss for epoch 13 is    3.87 and mean absolute error is    1.54.

Epoch 00014: Learning rate is 0.0010.
For batch 0, loss is    3.87.
For batch 1, loss is    3.87.
For batch 2, loss is    3.86.
For batch 3, loss is    3.86.
For batch 4, loss is    3.86.
The average loss for epoch 14 is    3.86 and mean absolute error is    1.54.
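The learning rates printed in the log can be reproduced without training a model. The sketch below uses hypothetical names (`SCHEDULE`, `scheduled_rate`, `trace_learning_rates`) and assumes that the rate set at the start of one epoch is carried into the next, as it is when the value is written back to the optimizer.

```python
# Same (epoch to start, learning rate) pairs as in the example above.
SCHEDULE = [(3, 0.05), (6, 0.01), (9, 0.005), (12, 0.001)]


def scheduled_rate(epoch, lr, schedule=SCHEDULE):
    """Return the rate listed for this epoch, or the current rate otherwise."""
    for start_epoch, new_lr in schedule:
        if epoch == start_epoch:
            return new_lr
    return lr


def trace_learning_rates(initial_lr, epochs):
    """Replay the schedule epoch by epoch, carrying the rate forward."""
    rates = []
    lr = initial_lr
    for epoch in range(epochs):
        lr = scheduled_rate(epoch, lr)
        rates.append(lr)
    return rates


rates = trace_learning_rates(initial_lr=0.1, epochs=15)
for epoch, lr in enumerate(rates):
    print('Epoch %05d: Learning rate is %6.4f.' % (epoch, lr))
```

The schedule function only returns a new value at the listed epochs. In between, it returns whatever rate it was given, so each step of the schedule stays in effect until the next listed epoch.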


