编写回调函数

回调函数在TensorFlow 2.0中是自定义模型行为的强大工具,例如在训练过程中调整学习率或可视化模型。它们在训练、测试和推理的不同阶段被自动调用,可以访问模型状态和统计数据。本文介绍了回调函数的基本概念,如`on_train_begin`和`on_batch_end`等方法,并给出学习率调整的回调函数示例。
摘要由CSDN通过智能技术生成

准备

!pip3 install tensorflow==2.0.0a0
%matplotlib inline
import tensorflow as tf
from tensorflow import keras
import numpy as np
from matplotlib import pyplot as plt
Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
Requirement already satisfied: tensorflow==2.0.0a0 in /usr/local/lib/python3.7/site-packages (2.0.0a0)
Requirement already satisfied: google-pasta>=0.1.2 in /usr/local/lib/python3.7/site-packages (from tensorflow==2.0.0a0) (0.1.4)
Requirement already satisfied: keras-applications>=1.0.6 in /usr/local/lib/python3.7/site-packages (from tensorflow==2.0.0a0) (1.0.7)
Requirement already satisfied: wheel>=0.26 in /usr/local/lib/python3.7/site-packages (from tensorflow==2.0.0a0) (0.33.1)
Requirement already satisfied: tf-estimator-nightly<1.14.0.dev2019030116,>=1.14.0.dev2019030115 in /usr/local/lib/python3.7/site-packages (from tensorflow==2.0.0a0) (1.14.0.dev2019030115)
Requirement already satisfied: absl-py>=0.7.0 in /Users/fei/Library/Python/3.7/lib/python/site-packages (from tensorflow==2.0.0a0) (0.7.0)
Requirement already satisfied: grpcio>=1.8.6 in /usr/local/lib/python3.7/site-packages (from tensorflow==2.0.0a0) (1.19.0)
Requirement already satisfied: keras-preprocessing>=1.0.5 in /usr/local/lib/python3.7/site-packages (from tensorflow==2.0.0a0) (1.0.9)
Requirement already satisfied: six>=1.10.0 in /Users/fei/Library/Python/3.7/lib/python/site-packages (from tensorflow==2.0.0a0) (1.12.0)
Requirement already satisfied: tb-nightly<1.14.0a20190302,>=1.14.0a20190301 in /usr/local/lib/python3.7/site-packages (from tensorflow==2.0.0a0) (1.14.0a20190301)
Requirement already satisfied: termcolor>=1.1.0 in /Users/fei/Library/Python/3.7/lib/python/site-packages (from tensorflow==2.0.0a0) (1.1.0)
Requirement already satisfied: numpy<2.0,>=1.14.5 in /usr/local/lib/python3.7/site-packages (from tensorflow==2.0.0a0) (1.16.2)
Requirement already satisfied: gast>=0.2.0 in /Users/fei/Library/Python/3.7/lib/python/site-packages (from tensorflow==2.0.0a0) (0.2.2)
Requirement already satisfied: protobuf>=3.6.1 in /usr/local/lib/python3.7/site-packages (from tensorflow==2.0.0a0) (3.7.0)
Requirement already satisfied: astor>=0.6.0 in /Users/fei/Library/Python/3.7/lib/python/site-packages (from tensorflow==2.0.0a0) (0.7.1)
Requirement already satisfied: h5py in /Users/fei/Library/Python/3.7/lib/python/site-packages (from keras-applications>=1.0.6->tensorflow==2.0.0a0) (2.9.0)
Requirement already satisfied: markdown>=2.6.8 in /Users/fei/Library/Python/3.7/lib/python/site-packages (from tb-nightly<1.14.0a20190302,>=1.14.0a20190301->tensorflow==2.0.0a0) (3.0.1)
Requirement already satisfied: werkzeug>=0.11.15 in /Users/fei/Library/Python/3.7/lib/python/site-packages (from tb-nightly<1.14.0a20190302,>=1.14.0a20190301->tensorflow==2.0.0a0) (0.14.1)
Requirement already satisfied: setuptools in /usr/local/lib/python3.7/site-packages (from protobuf>=3.6.1->tensorflow==2.0.0a0) (40.8.0)

keras回调函数

回调函数是一个强用力的工具,可以在训练、测试和推理的过程中自定义模型的行为,包括读取和修改模型等。一个典型的例子就是tf.keras.callbacks.TensorBoard可以将模型的训练过程和结果可视化到TensorBoard中。或者像tf.keras.callbacksModelCheckpoint可以在训练过程中自动保存模型。在本节中,你可以了解到回调函数的本质、何时调用的、回调能够做什么以及如何从头自己进行编写。本节的末尾还会有几个编写回调函数的例子。

keras回调函数介绍

对于Keras,回调函数是一个Python类,并包含一组在训练、测试和推理的不同过程被调用的函数(主要是在epochbatch的开始或结束时)。回调函数可以获得模型内部的状态和统计数据。可以通过关键字callbackstf.keras.Model.fit()tf.keras.Model.evaluate()tf.keras.Model.predict()函数传递一组列表类型的回调函数。这些回调函数会在不同的阶段被自动调用。
下面通过一个简单的序列模型的例子来进行介绍:

def get_model():
    model = keras.Sequential()
    model.add(keras.layers.Dense(1, activation='linear', input_dim=784))
    model.compile(optimizer=tf.keras.optimizers.RMSprop(lr=0.1), loss='mean_squared_error', metrics=['mae'])
    return model

加载MNIST数据:

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = x_train.reshape(60000, 784).astype('float32') / 255
x_test = x_test.reshape(10000, 784).astype('float32') / 255

下面编写一个点单的回调函数,它会在每个batch的开始和结束被自动调用,并打印当前batch的序号。

import datetime
class MyCustomCallback(keras.callbacks.Callback):
    def on_train_batch_begin(self, batch, logs=None):
        print('Training: batch {} begins at {}'.format(batch, datetime.datetime.now().time()))
    
    def on_train_batch_end(self, batch, logs=None):
        print('Training: batch {} ends at {}'.format(batch, datetime.datetime.now().time()))
    
    def on_test_batch_begin(self, batch, logs=None):
        print('Evaluating: batch {} begins at {}'.format(batch, datetime.datetime.now().time()))
    
    def on_test_batch_end(self, batch, logs=None):
        print('Evaluating: batch {} ends at {}'.format(batch, datetime.datetime.now().time()))

将该回调传递给fit方法:

model = get_model()
_ = model.fit(x_train, y_train, batch_size=64, epochs=1, steps_per_epoch=5, verbose=0, callbacks=[MyCustomCallback()])
Training: batch 0 begins at 13:26:24.561418
Training: batch 0 ends at 13:26:24.760678
Training: batch 1 begins at 13:26:24.761078
Training: batch 1 ends at 13:26:24.804738
Training: batch 2 begins at 13:26:24.805276
Training: batch 2 ends at 13:26:24.840322
Training: batch 3 begins at 13:26:24.840908
Training: batch 3 ends at 13:26:24.872811
Training: batch 4 begins at 13:26:24.873087
Training: batch 4 ends at 13:26:24.905033

接受callback参数的方法

用户可以将回调函数传递给下面的模型方法:

  • fit()
  • fit_generator()
  • evaluate()
  • evaluate_generator()
  • predict()
  • predict_generator()
_ = model.evaluate(x_test, y_test, batch_size=128, verbose=0, steps=5, callbacks=[MyCustomCallback()])
Evaluating: batch 0 begins at 13:26:26.929229
Evaluating: batch 0 ends at 13:26:26.980177
Evaluating: batch 1 begins at 13:26:26.980460
Evaluating: batch 1 ends at 13:26:26.983743
Evaluating: batch 2 begins at 13:26:26.984047
Evaluating: batch 2 ends at 13:26:26.987571
Evaluating: batch 3 begins at 13:26:26.987908
Evaluating: batch 3 ends at 13:26:26.991281
Evaluating: batch 4 begins at 13:26:26.991655
Evaluating: batch 4 ends at 13:26:26.995114

回调方法的概述

通用方法

对于训练、测试和推理,提供了下列可重写的方法:

  • on_(train|test|predict)_begin(self, logs=None),在训练、测试和推理开始的时候被调用
  • on_(train|test|predict)_end(self, logs=None), 在训练、测试和推理结束的时候被调用
  • on_(train|test|predict)_batch_begin(self, batch, logs=None),在训练、测试和推理的每个批次开始的时候被调用。其中log是一个包含了两个成员的字典,batchsize,分别代表了批次的编号和尺寸。
  • on_(train|test|predict)_batch_end(self, batch, logs=None),在训练、测试和推理的每个批次结束的时候被调用

训练使用的方法

专门对于训练提供了下面的方法:

  • on_epoch_begin(self, epoch, logs=None),在epoch开始调用
  • on_epoch_end(self, epoch, logs=None),在epoch结束时调用

logs的特殊用法

对于batchepoch结束时的logs参数,包含了所有的metricsloss值,可以通过字典的方式获取到。

class LossAndErrorPrintingCallback(tf.keras.callbacks.Callback):
    def on_train_batch_end(self, batch, logs=None):
        print('For batch {}, loss is {:7.2f}.'.format(batch, logs['loss']))

    def on_test_batch_end(self, batch, logs=None):
        print('For batch {}, loss is {:7.2f}.'.format(batch, logs['loss']))

    def on_epoch_end(self, epoch, logs=None):
        print('The average loss for epoch {} is {:7.2f} and mean absolute error is {:7.2f}.'.format(epoch, logs['loss'], logs['mae']))

model = get_model()
_ = model.fit(x_train, y_train, batch_size=64, steps_per_epoch=5, epochs=3, verbose=0, callbacks=[LossAndErrorPrintingCallback()])
For batch 0, loss is   32.18.
For batch 1, loss is  897.99.
For batch 2, loss is   28.40.
For batch 3, loss is    9.22.
For batch 4, loss is    7.45.
The average loss for epoch 0 is  195.05 and mean absolute error is    8.40.
For batch 0, loss is    6.59.
For batch 1, loss is    6.02.
For batch 2, loss is    5.61.
For batch 3, loss is    5.33.
For batch 4, loss is    5.12.
The average loss for epoch 1 is    5.73 and mean absolute error is    2.00.
For batch 0, loss is    4.96.
For batch 1, loss is    4.84.
For batch 2, loss is    4.74.
For batch 3, loss is    4.65.
For batch 4, loss is    4.57.
The average loss for epoch 2 is    4.75 and mean absolute error is    1.78.

当然你也可以将该回调传递给evaluate()方法。

_ = model.evaluate(x_test, y_test, batch_size=128, verbose=0, steps=20, callbacks=[LossAndErrorPrintingCallback()])
For batch 0, loss is    4.26.
For batch 1, loss is    4.26.
For batch 2, loss is    4.26.
For batch 3, loss is    4.26.
For batch 4, loss is    4.26.
For batch 5, loss is    4.26.
For batch 6, loss is    4.26.
For batch 7, loss is    4.26.
For batch 8, loss is    4.26.
For batch 9, loss is    4.26.
For batch 10, loss is    4.26.
For batch 11, loss is    4.26.
For batch 12, loss is    4.26.
For batch 13, loss is    4.26.
For batch 14, loss is    4.26.
For batch 15, loss is    4.26.
For batch 16, loss is    4.26.
For batch 17, loss is    4.26.
For batch 18, loss is    4.26.
For batch 19, loss is    4.26.

回调函数编写例子

下面是两个从头编写回调函数的例子:

early stop

该例子实现了当获取到最小的loss时停止训练(通过一个model.stop_training的布尔值)。同时用户可以传递一个可选的patience,来指示在达到最小的loss后,等待多少epoch后才停止。
tf.keras.callbacks.EarlyStopping提供了一个更加复杂和通用的实现,但是我们这里仅考虑简单的情况。

class EarlyStoppingAtMinLoss(keras.callbacks.Callback):
    def __init__(self, patience=0):
        super(EarlyStoppingAtMinLoss, self).__init__()
        self.patience = patience
        self.best_weights = None
        
    def on_train_begin(self, logs=None):
        # 耐心值
        self.wait = 0
        # 停止训练时的epoch
        self.stopped_epoch = 0
        # 初始化一个无穷值作为最好值
        self.best = np.Inf
        
    def on_epoch_end(self, epoch, logs=None):
        current = logs.get('loss')
        if np.less(current, self.best):
            self.best = current
            self.wait = 0
            self.best_weights = self.model.get_weights()
        else:
            self.wait += 1
            if self.wait > self.patience:
                self.stopped_epoch = epoch
                self.model.stop_training = True
                print('Restoring model weights from the end of the best epoch.')
                self.model.set_weights(self.best_weights)
        
    def on_train_end(self, logs=None):
        if self.stopped_epoch > 0:
            print('Epoch %05d: early stopping' % (self.stopped_epoch + 1))
model = get_model()
_ = model.fit(x_train, y_train, batch_size=64, steps_per_epoch=5, epochs=30, verbose=0, callbacks=[LossAndErrorPrintingCallback(), EarlyStoppingAtMinLoss()])
For batch 0, loss is   24.84.
For batch 1, loss is  957.64.
For batch 2, loss is   22.89.
For batch 3, loss is    9.42.
For batch 4, loss is    7.78.
The average loss for epoch 0 is  204.52 and mean absolute error is    8.35.
For batch 0, loss is    6.81.
For batch 1, loss is    6.15.
For batch 2, loss is    5.68.
For batch 3, loss is    5.34.
For batch 4, loss is    5.11.
The average loss for epoch 1 is    5.82 and mean absolute error is    2.01.
For batch 0, loss is    4.93.
For batch 1, loss is    4.79.
For batch 2, loss is    4.69.
For batch 3, loss is    4.60.
For batch 4, loss is    4.52.
The average loss for epoch 2 is    4.71 and mean absolute error is    1.77.
For batch 0, loss is    4.45.
For batch 1, loss is    4.39.
For batch 2, loss is    4.33.
For batch 3, loss is    4.28.
For batch 4, loss is    4.23.
The average loss for epoch 3 is    4.33 and mean absolute error is    1.67.
For batch 0, loss is    4.18.
For batch 1, loss is    4.14.
For batch 2, loss is    4.10.
For batch 3, loss is    4.07.
For batch 4, loss is    4.08.
The average loss for epoch 4 is    4.12 and mean absolute error is    1.61.
For batch 0, loss is    4.38.
For batch 1, loss is    7.17.
For batch 2, loss is   32.97.
For batch 3, loss is  188.94.
For batch 4, loss is  224.35.
The average loss for epoch 5 is   91.56 and mean absolute error is    7.18.
Restoring model weights from the end of the best epoch.
Epoch 00006: early stopping

学习率调整

当下训练中另一件经常用到的就是根据训练的进行来调整学习率。本例中我们演示了如何使用回调函数来动态调整学习率。
keras.callbacks.LearningRateScheduler提供了更加复杂的实现。

class LearningRateScheduler(tf.keras.callbacks.Callback):
    """
    schedule是一个方法,接受epochs和当前学习率,然后返回一个新的学习率
    """
    def __init__(self, schedule):
        super(LearningRateScheduler, self).__init__()
        self.schedule = schedule
        
    def on_epoch_begin(self, epoch, logs=None):
        if not hasattr(self.model.optimizer, 'lr'):
            raise ValueError('Optimizer must have a "lr" attribute.')
        # Get the current learning rate from model's optimizer.
        lr = float(tf.keras.backend.get_value(self.model.optimizer.lr))
        # Call schedule function to get the scheduled learning rate.
        scheduled_lr = self.schedule(epoch, lr)
        # Set the value back to the optimizer before this epoch starts
        tf.keras.backend.set_value(self.model.optimizer.lr, scheduled_lr)
        print('\nEpoch %05d: Learning rate is %6.4f.' % (epoch, scheduled_lr))

LR_SCHEDULE = [
    # (epoch to start, learning rate) tuples
    (3, 0.05), (6, 0.01), (9, 0.005), (12, 0.001)
]

def lr_schedule(epoch, lr):
    """Helper function to retrieve the scheduled learning rate based on epoch."""
    if epoch < LR_SCHEDULE[0][0] or epoch > LR_SCHEDULE[-1][0]:
        return lr
    for i in range(len(LR_SCHEDULE)):
        if epoch == LR_SCHEDULE[i][0]:
            return LR_SCHEDULE[i][1]
    return lr

model = get_model()
_ = model.fit(x_train, y_train, batch_size=64, steps_per_epoch=5, epochs=15, verbose=0, callbacks=[LossAndErrorPrintingCallback(), LearningRateScheduler(lr_schedule)])
Epoch 00000: Learning rate is 0.1000.
For batch 0, loss is   26.46.
For batch 1, loss is  934.59.
For batch 2, loss is   23.98.
For batch 3, loss is    8.68.
For batch 4, loss is    7.13.
The average loss for epoch 0 is  200.17 and mean absolute error is    8.30.

Epoch 00001: Learning rate is 0.1000.
For batch 0, loss is    6.31.
For batch 1, loss is    5.76.
For batch 2, loss is    5.37.
For batch 3, loss is    5.10.
For batch 4, loss is    4.91.
The average loss for epoch 1 is    5.49 and mean absolute error is    1.95.

Epoch 00002: Learning rate is 0.1000.
For batch 0, loss is    4.77.
For batch 1, loss is    4.66.
For batch 2, loss is    4.57.
For batch 3, loss is    4.50.
For batch 4, loss is    4.44.
The average loss for epoch 2 is    4.59 and mean absolute error is    1.74.

Epoch 00003: Learning rate is 0.0500.
For batch 0, loss is    4.38.
For batch 1, loss is    4.35.
For batch 2, loss is    4.32.
For batch 3, loss is    4.30.
For batch 4, loss is    4.27.
The average loss for epoch 3 is    4.32 and mean absolute error is    1.67.

Epoch 00004: Learning rate is 0.0500.
For batch 0, loss is    4.25.
For batch 1, loss is    4.22.
For batch 2, loss is    4.20.
For batch 3, loss is    4.17.
For batch 4, loss is    4.15.
The average loss for epoch 4 is    4.20 and mean absolute error is    1.64.

Epoch 00005: Learning rate is 0.0500.
For batch 0, loss is    4.12.
For batch 1, loss is    4.10.
For batch 2, loss is    4.08.
For batch 3, loss is    4.06.
For batch 4, loss is    4.04.
The average loss for epoch 5 is    4.08 and mean absolute error is    1.60.

Epoch 00006: Learning rate is 0.0100.
For batch 0, loss is    4.02.
For batch 1, loss is    4.01.
For batch 2, loss is    4.01.
For batch 3, loss is    4.00.
For batch 4, loss is    4.00.
The average loss for epoch 6 is    4.01 and mean absolute error is    1.58.

Epoch 00007: Learning rate is 0.0100.
For batch 0, loss is    3.99.
For batch 1, loss is    3.99.
For batch 2, loss is    3.98.
For batch 3, loss is    3.98.
For batch 4, loss is    3.97.
The average loss for epoch 7 is    3.98 and mean absolute error is    1.57.

Epoch 00008: Learning rate is 0.0100.
For batch 0, loss is    3.97.
For batch 1, loss is    3.96.
For batch 2, loss is    3.96.
For batch 3, loss is    3.95.
For batch 4, loss is    3.94.
The average loss for epoch 8 is    3.96 and mean absolute error is    1.57.

Epoch 00009: Learning rate is 0.0050.
For batch 0, loss is    3.94.
For batch 1, loss is    3.94.
For batch 2, loss is    3.93.
For batch 3, loss is    3.93.
For batch 4, loss is    3.93.
The average loss for epoch 9 is    3.93 and mean absolute error is    1.56.

Epoch 00010: Learning rate is 0.0050.
For batch 0, loss is    3.92.
For batch 1, loss is    3.92.
For batch 2, loss is    3.91.
For batch 3, loss is    3.91.
For batch 4, loss is    3.91.
The average loss for epoch 10 is    3.91 and mean absolute error is    1.55.

Epoch 00011: Learning rate is 0.0050.
For batch 0, loss is    3.90.
For batch 1, loss is    3.90.
For batch 2, loss is    3.89.
For batch 3, loss is    3.89.
For batch 4, loss is    3.88.
The average loss for epoch 11 is    3.89 and mean absolute error is    1.55.

Epoch 00012: Learning rate is 0.0010.
For batch 0, loss is    3.88.
For batch 1, loss is    3.88.
For batch 2, loss is    3.88.
For batch 3, loss is    3.88.
For batch 4, loss is    3.87.
The average loss for epoch 12 is    3.88 and mean absolute error is    1.54.

Epoch 00013: Learning rate is 0.0010.
For batch 0, loss is    3.87.
For batch 1, loss is    3.87.
For batch 2, loss is    3.87.
For batch 3, loss is    3.87.
For batch 4, loss is    3.87.
The average loss for epoch 13 is    3.87 and mean absolute error is    1.54.

Epoch 00014: Learning rate is 0.0010.
For batch 0, loss is    3.87.
For batch 1, loss is    3.87.
For batch 2, loss is    3.86.
For batch 3, loss is    3.86.
For batch 4, loss is    3.86.
The average loss for epoch 14 is    3.86 and mean absolute error is    1.54.

更多关于内置回调函数,可以阅读TF的API文档。

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值