Preparation
!pip3 install tensorflow==2.0.0a0
%matplotlib inline
import tensorflow as tf
from tensorflow import keras
import numpy as np
from matplotlib import pyplot as plt
Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
Requirement already satisfied: tensorflow==2.0.0a0 in /usr/local/lib/python3.7/site-packages (2.0.0a0)
Requirement already satisfied: google-pasta>=0.1.2 in /usr/local/lib/python3.7/site-packages (from tensorflow==2.0.0a0) (0.1.4)
Requirement already satisfied: keras-applications>=1.0.6 in /usr/local/lib/python3.7/site-packages (from tensorflow==2.0.0a0) (1.0.7)
Requirement already satisfied: wheel>=0.26 in /usr/local/lib/python3.7/site-packages (from tensorflow==2.0.0a0) (0.33.1)
Requirement already satisfied: tf-estimator-nightly<1.14.0.dev2019030116,>=1.14.0.dev2019030115 in /usr/local/lib/python3.7/site-packages (from tensorflow==2.0.0a0) (1.14.0.dev2019030115)
Requirement already satisfied: absl-py>=0.7.0 in /Users/fei/Library/Python/3.7/lib/python/site-packages (from tensorflow==2.0.0a0) (0.7.0)
Requirement already satisfied: grpcio>=1.8.6 in /usr/local/lib/python3.7/site-packages (from tensorflow==2.0.0a0) (1.19.0)
Requirement already satisfied: keras-preprocessing>=1.0.5 in /usr/local/lib/python3.7/site-packages (from tensorflow==2.0.0a0) (1.0.9)
Requirement already satisfied: six>=1.10.0 in /Users/fei/Library/Python/3.7/lib/python/site-packages (from tensorflow==2.0.0a0) (1.12.0)
Requirement already satisfied: tb-nightly<1.14.0a20190302,>=1.14.0a20190301 in /usr/local/lib/python3.7/site-packages (from tensorflow==2.0.0a0) (1.14.0a20190301)
Requirement already satisfied: termcolor>=1.1.0 in /Users/fei/Library/Python/3.7/lib/python/site-packages (from tensorflow==2.0.0a0) (1.1.0)
Requirement already satisfied: numpy<2.0,>=1.14.5 in /usr/local/lib/python3.7/site-packages (from tensorflow==2.0.0a0) (1.16.2)
Requirement already satisfied: gast>=0.2.0 in /Users/fei/Library/Python/3.7/lib/python/site-packages (from tensorflow==2.0.0a0) (0.2.2)
Requirement already satisfied: protobuf>=3.6.1 in /usr/local/lib/python3.7/site-packages (from tensorflow==2.0.0a0) (3.7.0)
Requirement already satisfied: astor>=0.6.0 in /Users/fei/Library/Python/3.7/lib/python/site-packages (from tensorflow==2.0.0a0) (0.7.1)
Requirement already satisfied: h5py in /Users/fei/Library/Python/3.7/lib/python/site-packages (from keras-applications>=1.0.6->tensorflow==2.0.0a0) (2.9.0)
Requirement already satisfied: markdown>=2.6.8 in /Users/fei/Library/Python/3.7/lib/python/site-packages (from tb-nightly<1.14.0a20190302,>=1.14.0a20190301->tensorflow==2.0.0a0) (3.0.1)
Requirement already satisfied: werkzeug>=0.11.15 in /Users/fei/Library/Python/3.7/lib/python/site-packages (from tb-nightly<1.14.0a20190302,>=1.14.0a20190301->tensorflow==2.0.0a0) (0.14.1)
Requirement already satisfied: setuptools in /usr/local/lib/python3.7/site-packages (from protobuf>=3.6.1->tensorflow==2.0.0a0) (40.8.0)
Keras callbacks
Callbacks are a powerful tool for customizing the behavior of a model during training, evaluation, and inference, including reading and modifying the model. A typical example is tf.keras.callbacks.TensorBoard, which visualizes the training process and results in TensorBoard, or tf.keras.callbacks.ModelCheckpoint, which automatically saves the model during training. In this section you will learn what a callback really is, when it is called, what it can do, and how to write one from scratch. Several examples of writing callbacks are given at the end of the section.
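As a quick illustration, the two built-in callbacks mentioned above can simply be appended to the callbacks list passed to fit(). A minimal sketch, assuming the log and checkpoint paths are arbitrary placeholders and using the model and data defined later in this section (the fit call is left commented out for that reason):
# Hypothetical paths; any writable directory/file will do.
tensorboard_cb = tf.keras.callbacks.TensorBoard(log_dir='./logs')
checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(filepath='./model_checkpoint.h5')
# model.fit(x_train, y_train, batch_size=64, epochs=1,
#           callbacks=[tensorboard_cb, checkpoint_cb])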
Introduction to Keras callbacks
In Keras, a callback is a Python class containing a set of methods that are called at different stages of training, evaluation, and inference (mainly at the beginning or end of each epoch and batch). Callbacks can access the model's internal state and statistics. A list of callbacks can be passed via the callbacks keyword argument to tf.keras.Model.fit(), tf.keras.Model.evaluate(), and tf.keras.Model.predict(); these callbacks are then called automatically at the corresponding stages.
Let's introduce them using a simple Sequential model as an example:
def get_model():
    model = keras.Sequential()
    model.add(keras.layers.Dense(1, activation='linear', input_dim=784))
    model.compile(optimizer=tf.keras.optimizers.RMSprop(lr=0.1), loss='mean_squared_error', metrics=['mae'])
    return model
Load the MNIST data:
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = x_train.reshape(60000, 784).astype('float32') / 255
x_test = x_test.reshape(10000, 784).astype('float32') / 255
Next, let's write a simple callback that is called automatically at the beginning and end of each batch and prints the index of the current batch.
import datetime
class MyCustomCallback(keras.callbacks.Callback):

    def on_train_batch_begin(self, batch, logs=None):
        print('Training: batch {} begins at {}'.format(batch, datetime.datetime.now().time()))

    def on_train_batch_end(self, batch, logs=None):
        print('Training: batch {} ends at {}'.format(batch, datetime.datetime.now().time()))

    def on_test_batch_begin(self, batch, logs=None):
        print('Evaluating: batch {} begins at {}'.format(batch, datetime.datetime.now().time()))

    def on_test_batch_end(self, batch, logs=None):
        print('Evaluating: batch {} ends at {}'.format(batch, datetime.datetime.now().time()))
Pass this callback to the fit method:
model = get_model()
_ = model.fit(x_train, y_train, batch_size=64, epochs=1, steps_per_epoch=5, verbose=0, callbacks=[MyCustomCallback()])
Training: batch 0 begins at 13:26:24.561418
Training: batch 0 ends at 13:26:24.760678
Training: batch 1 begins at 13:26:24.761078
Training: batch 1 ends at 13:26:24.804738
Training: batch 2 begins at 13:26:24.805276
Training: batch 2 ends at 13:26:24.840322
Training: batch 3 begins at 13:26:24.840908
Training: batch 3 ends at 13:26:24.872811
Training: batch 4 begins at 13:26:24.873087
Training: batch 4 ends at 13:26:24.905033
Methods that accept the callbacks argument
Callbacks can be passed to the following model methods:
fit()
fit_generator()
evaluate()
evaluate_generator()
predict()
predict_generator()
_ = model.evaluate(x_test, y_test, batch_size=128, verbose=0, steps=5, callbacks=[MyCustomCallback()])
Evaluating: batch 0 begins at 13:26:26.929229
Evaluating: batch 0 ends at 13:26:26.980177
Evaluating: batch 1 begins at 13:26:26.980460
Evaluating: batch 1 ends at 13:26:26.983743
Evaluating: batch 2 begins at 13:26:26.984047
Evaluating: batch 2 ends at 13:26:26.987571
Evaluating: batch 3 begins at 13:26:26.987908
Evaluating: batch 3 ends at 13:26:26.991281
Evaluating: batch 4 begins at 13:26:26.991655
Evaluating: batch 4 ends at 13:26:26.995114
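predict() accepts callbacks in the same way. MyCustomCallback above does not override the predict-batch hooks, so as a minimal sketch (not from the original output), a callback that does could look like this:
class PredictBatchLogger(keras.callbacks.Callback):

    def on_predict_batch_end(self, batch, logs=None):
        # Print the index of each inference batch as it finishes.
        print('Predicting: batch {} ends at {}'.format(batch, datetime.datetime.now().time()))

_ = model.predict(x_test, batch_size=128, steps=5, callbacks=[PredictBatchLogger()])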
An overview of callback methods
Common methods
For training, evaluation, and inference, the following overridable methods are provided (a small sketch follows the list):
on_(train|test|predict)_begin(self, logs=None), called at the beginning of training, evaluation, or inference.
on_(train|test|predict)_end(self, logs=None), called at the end of training, evaluation, or inference.
on_(train|test|predict)_batch_begin(self, batch, logs=None), called at the beginning of each batch during training, evaluation, or inference. Here logs is a dict with two keys, batch and size, giving the batch index and the batch size.
on_(train|test|predict)_batch_end(self, batch, logs=None), called at the end of each batch during training, evaluation, or inference.
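For instance, a minimal sketch (following the same pattern as MyCustomCallback above, not part of the original guide) that uses the run-level hooks to time an entire training run:
import time

class TrainTimer(keras.callbacks.Callback):

    def on_train_begin(self, logs=None):
        # Record the wall-clock time at the start of the training run.
        self.train_start = time.time()

    def on_train_end(self, logs=None):
        print('Training took {:.1f} seconds.'.format(time.time() - self.train_start))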
Training-specific methods
The following methods are provided specifically for training (a small sketch follows the list):
on_epoch_begin(self, epoch, logs=None), called at the beginning of an epoch.
on_epoch_end(self, epoch, logs=None), called at the end of an epoch.
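A parallel sketch for the epoch-level hooks, timing each epoch rather than the whole run (reusing the time module imported in the previous sketch; again an illustrative assumption, not from the original guide):
class EpochTimer(keras.callbacks.Callback):

    def on_epoch_begin(self, epoch, logs=None):
        # Record the wall-clock time at the start of the epoch.
        self.epoch_start = time.time()

    def on_epoch_end(self, epoch, logs=None):
        print('Epoch {} took {:.1f} seconds.'.format(epoch, time.time() - self.epoch_start))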
Special usage of logs
The logs argument passed at the end of a batch or an epoch contains all of the metric and loss values, which can be retrieved from it like a dictionary.
class LossAndErrorPrintingCallback(tf.keras.callbacks.Callback):

    def on_train_batch_end(self, batch, logs=None):
        print('For batch {}, loss is {:7.2f}.'.format(batch, logs['loss']))

    def on_test_batch_end(self, batch, logs=None):
        print('For batch {}, loss is {:7.2f}.'.format(batch, logs['loss']))

    def on_epoch_end(self, epoch, logs=None):
        print('The average loss for epoch {} is {:7.2f} and mean absolute error is {:7.2f}.'.format(epoch, logs['loss'], logs['mae']))
model = get_model()
_ = model.fit(x_train, y_train, batch_size=64, steps_per_epoch=5, epochs=3, verbose=0, callbacks=[LossAndErrorPrintingCallback()])
For batch 0, loss is 32.18.
For batch 1, loss is 897.99.
For batch 2, loss is 28.40.
For batch 3, loss is 9.22.
For batch 4, loss is 7.45.
The average loss for epoch 0 is 195.05 and mean absolute error is 8.40.
For batch 0, loss is 6.59.
For batch 1, loss is 6.02.
For batch 2, loss is 5.61.
For batch 3, loss is 5.33.
For batch 4, loss is 5.12.
The average loss for epoch 1 is 5.73 and mean absolute error is 2.00.
For batch 0, loss is 4.96.
For batch 1, loss is 4.84.
For batch 2, loss is 4.74.
For batch 3, loss is 4.65.
For batch 4, loss is 4.57.
The average loss for epoch 2 is 4.75 and mean absolute error is 1.78.
Of course, the same callback can also be passed to the evaluate() method.
_ = model.evaluate(x_test, y_test, batch_size=128, verbose=0, steps=20, callbacks=[LossAndErrorPrintingCallback()])
For batch 0, loss is 4.26.
For batch 1, loss is 4.26.
For batch 2, loss is 4.26.
For batch 3, loss is 4.26.
For batch 4, loss is 4.26.
For batch 5, loss is 4.26.
For batch 6, loss is 4.26.
For batch 7, loss is 4.26.
For batch 8, loss is 4.26.
For batch 9, loss is 4.26.
For batch 10, loss is 4.26.
For batch 11, loss is 4.26.
For batch 12, loss is 4.26.
For batch 13, loss is 4.26.
For batch 14, loss is 4.26.
For batch 15, loss is 4.26.
For batch 16, loss is 4.26.
For batch 17, loss is 4.26.
For batch 18, loss is 4.26.
For batch 19, loss is 4.26.
Examples of writing callbacks
Below are two examples of writing a callback from scratch:
Early stopping
This example stops training once the minimum loss has been reached, by setting the boolean attribute model.stop_training. The user can also pass an optional patience argument specifying how many epochs to wait after the minimum loss before actually stopping.
tf.keras.callbacks.EarlyStopping provides a more complete and general implementation, but here we only consider the simple case.
class EarlyStoppingAtMinLoss(keras.callbacks.Callback):

    def __init__(self, patience=0):
        super(EarlyStoppingAtMinLoss, self).__init__()
        self.patience = patience
        self.best_weights = None

    def on_train_begin(self, logs=None):
        # Number of epochs waited since the loss last improved.
        self.wait = 0
        # The epoch at which training stopped.
        self.stopped_epoch = 0
        # Initialize the best loss as infinity.
        self.best = np.Inf

    def on_epoch_end(self, epoch, logs=None):
        current = logs.get('loss')
        if np.less(current, self.best):
            self.best = current
            self.wait = 0
            self.best_weights = self.model.get_weights()
        else:
            self.wait += 1
            if self.wait > self.patience:
                self.stopped_epoch = epoch
                self.model.stop_training = True
                print('Restoring model weights from the end of the best epoch.')
                self.model.set_weights(self.best_weights)

    def on_train_end(self, logs=None):
        if self.stopped_epoch > 0:
            print('Epoch %05d: early stopping' % (self.stopped_epoch + 1))
model = get_model()
_ = model.fit(x_train, y_train, batch_size=64, steps_per_epoch=5, epochs=30, verbose=0, callbacks=[LossAndErrorPrintingCallback(), EarlyStoppingAtMinLoss()])
For batch 0, loss is 24.84.
For batch 1, loss is 957.64.
For batch 2, loss is 22.89.
For batch 3, loss is 9.42.
For batch 4, loss is 7.78.
The average loss for epoch 0 is 204.52 and mean absolute error is 8.35.
For batch 0, loss is 6.81.
For batch 1, loss is 6.15.
For batch 2, loss is 5.68.
For batch 3, loss is 5.34.
For batch 4, loss is 5.11.
The average loss for epoch 1 is 5.82 and mean absolute error is 2.01.
For batch 0, loss is 4.93.
For batch 1, loss is 4.79.
For batch 2, loss is 4.69.
For batch 3, loss is 4.60.
For batch 4, loss is 4.52.
The average loss for epoch 2 is 4.71 and mean absolute error is 1.77.
For batch 0, loss is 4.45.
For batch 1, loss is 4.39.
For batch 2, loss is 4.33.
For batch 3, loss is 4.28.
For batch 4, loss is 4.23.
The average loss for epoch 3 is 4.33 and mean absolute error is 1.67.
For batch 0, loss is 4.18.
For batch 1, loss is 4.14.
For batch 2, loss is 4.10.
For batch 3, loss is 4.07.
For batch 4, loss is 4.08.
The average loss for epoch 4 is 4.12 and mean absolute error is 1.61.
For batch 0, loss is 4.38.
For batch 1, loss is 7.17.
For batch 2, loss is 32.97.
For batch 3, loss is 188.94.
For batch 4, loss is 224.35.
The average loss for epoch 5 is 91.56 and mean absolute error is 7.18.
Restoring model weights from the end of the best epoch.
Epoch 00006: early stopping
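For comparison, the built-in tf.keras.callbacks.EarlyStopping mentioned above covers the same idea with less code. A rough sketch of how it might be configured to monitor the training loss (the parameter choices here are illustrative assumptions, and the fit call is left commented out):
builtin_early_stop = tf.keras.callbacks.EarlyStopping(
    monitor='loss',             # watch the training loss, as the custom callback does
    patience=2,                 # wait two epochs after the best value before stopping
    restore_best_weights=True)  # roll back to the best weights when stopping
# model.fit(x_train, y_train, batch_size=64, steps_per_epoch=5, epochs=30,
#           callbacks=[builtin_early_stop])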
Learning rate scheduling
Another thing frequently done during training is adjusting the learning rate as training progresses. This example demonstrates how to use a callback to adjust the learning rate dynamically. keras.callbacks.LearningRateScheduler provides a more complete implementation.
class LearningRateScheduler(tf.keras.callbacks.Callback):
    """Learning rate scheduler.

    schedule is a function that takes the epoch index and the current
    learning rate and returns a new learning rate.
    """

    def __init__(self, schedule):
        super(LearningRateScheduler, self).__init__()
        self.schedule = schedule

    def on_epoch_begin(self, epoch, logs=None):
        if not hasattr(self.model.optimizer, 'lr'):
            raise ValueError('Optimizer must have a "lr" attribute.')
        # Get the current learning rate from the model's optimizer.
        lr = float(tf.keras.backend.get_value(self.model.optimizer.lr))
        # Call the schedule function to get the scheduled learning rate.
        scheduled_lr = self.schedule(epoch, lr)
        # Set the value back on the optimizer before this epoch starts.
        tf.keras.backend.set_value(self.model.optimizer.lr, scheduled_lr)
        print('\nEpoch %05d: Learning rate is %6.4f.' % (epoch, scheduled_lr))
LR_SCHEDULE = [
    # (epoch to start, learning rate) tuples
    (3, 0.05), (6, 0.01), (9, 0.005), (12, 0.001)
]

def lr_schedule(epoch, lr):
    """Helper function to retrieve the scheduled learning rate based on epoch."""
    if epoch < LR_SCHEDULE[0][0] or epoch > LR_SCHEDULE[-1][0]:
        return lr
    for i in range(len(LR_SCHEDULE)):
        if epoch == LR_SCHEDULE[i][0]:
            return LR_SCHEDULE[i][1]
    return lr
model = get_model()
_ = model.fit(x_train, y_train, batch_size=64, steps_per_epoch=5, epochs=15, verbose=0, callbacks=[LossAndErrorPrintingCallback(), LearningRateScheduler(lr_schedule)])
Epoch 00000: Learning rate is 0.1000.
For batch 0, loss is 26.46.
For batch 1, loss is 934.59.
For batch 2, loss is 23.98.
For batch 3, loss is 8.68.
For batch 4, loss is 7.13.
The average loss for epoch 0 is 200.17 and mean absolute error is 8.30.
Epoch 00001: Learning rate is 0.1000.
For batch 0, loss is 6.31.
For batch 1, loss is 5.76.
For batch 2, loss is 5.37.
For batch 3, loss is 5.10.
For batch 4, loss is 4.91.
The average loss for epoch 1 is 5.49 and mean absolute error is 1.95.
Epoch 00002: Learning rate is 0.1000.
For batch 0, loss is 4.77.
For batch 1, loss is 4.66.
For batch 2, loss is 4.57.
For batch 3, loss is 4.50.
For batch 4, loss is 4.44.
The average loss for epoch 2 is 4.59 and mean absolute error is 1.74.
Epoch 00003: Learning rate is 0.0500.
For batch 0, loss is 4.38.
For batch 1, loss is 4.35.
For batch 2, loss is 4.32.
For batch 3, loss is 4.30.
For batch 4, loss is 4.27.
The average loss for epoch 3 is 4.32 and mean absolute error is 1.67.
Epoch 00004: Learning rate is 0.0500.
For batch 0, loss is 4.25.
For batch 1, loss is 4.22.
For batch 2, loss is 4.20.
For batch 3, loss is 4.17.
For batch 4, loss is 4.15.
The average loss for epoch 4 is 4.20 and mean absolute error is 1.64.
Epoch 00005: Learning rate is 0.0500.
For batch 0, loss is 4.12.
For batch 1, loss is 4.10.
For batch 2, loss is 4.08.
For batch 3, loss is 4.06.
For batch 4, loss is 4.04.
The average loss for epoch 5 is 4.08 and mean absolute error is 1.60.
Epoch 00006: Learning rate is 0.0100.
For batch 0, loss is 4.02.
For batch 1, loss is 4.01.
For batch 2, loss is 4.01.
For batch 3, loss is 4.00.
For batch 4, loss is 4.00.
The average loss for epoch 6 is 4.01 and mean absolute error is 1.58.
Epoch 00007: Learning rate is 0.0100.
For batch 0, loss is 3.99.
For batch 1, loss is 3.99.
For batch 2, loss is 3.98.
For batch 3, loss is 3.98.
For batch 4, loss is 3.97.
The average loss for epoch 7 is 3.98 and mean absolute error is 1.57.
Epoch 00008: Learning rate is 0.0100.
For batch 0, loss is 3.97.
For batch 1, loss is 3.96.
For batch 2, loss is 3.96.
For batch 3, loss is 3.95.
For batch 4, loss is 3.94.
The average loss for epoch 8 is 3.96 and mean absolute error is 1.57.
Epoch 00009: Learning rate is 0.0050.
For batch 0, loss is 3.94.
For batch 1, loss is 3.94.
For batch 2, loss is 3.93.
For batch 3, loss is 3.93.
For batch 4, loss is 3.93.
The average loss for epoch 9 is 3.93 and mean absolute error is 1.56.
Epoch 00010: Learning rate is 0.0050.
For batch 0, loss is 3.92.
For batch 1, loss is 3.92.
For batch 2, loss is 3.91.
For batch 3, loss is 3.91.
For batch 4, loss is 3.91.
The average loss for epoch 10 is 3.91 and mean absolute error is 1.55.
Epoch 00011: Learning rate is 0.0050.
For batch 0, loss is 3.90.
For batch 1, loss is 3.90.
For batch 2, loss is 3.89.
For batch 3, loss is 3.89.
For batch 4, loss is 3.88.
The average loss for epoch 11 is 3.89 and mean absolute error is 1.55.
Epoch 00012: Learning rate is 0.0010.
For batch 0, loss is 3.88.
For batch 1, loss is 3.88.
For batch 2, loss is 3.88.
For batch 3, loss is 3.88.
For batch 4, loss is 3.87.
The average loss for epoch 12 is 3.88 and mean absolute error is 1.54.
Epoch 00013: Learning rate is 0.0010.
For batch 0, loss is 3.87.
For batch 1, loss is 3.87.
For batch 2, loss is 3.87.
For batch 3, loss is 3.87.
For batch 4, loss is 3.87.
The average loss for epoch 13 is 3.87 and mean absolute error is 1.54.
Epoch 00014: Learning rate is 0.0010.
For batch 0, loss is 3.87.
For batch 1, loss is 3.87.
For batch 2, loss is 3.86.
For batch 3, loss is 3.86.
For batch 4, loss is 3.86.
The average loss for epoch 14 is 3.86 and mean absolute error is 1.54.
For more on the built-in callbacks, see the TensorFlow API documentation.
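As one pointer, the built-in keras.callbacks.LearningRateScheduler mentioned earlier can replace the custom class in this section. A minimal sketch reusing lr_schedule, assuming the installed version's LearningRateScheduler accepts a two-argument schedule function (the fit call is left commented out):
builtin_scheduler = tf.keras.callbacks.LearningRateScheduler(lr_schedule, verbose=1)
# model.fit(x_train, y_train, batch_size=64, steps_per_epoch=5, epochs=15,
#           callbacks=[LossAndErrorPrintingCallback(), builtin_scheduler])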