【Keras】keras回调——详解(2)

参考:https://keras.io/api/callbacks/

一、概述

  回调是在训练的各个阶段(如epoch开始前、批量开始前等)执行的动作。可以使用回调来:

  • 每批量训练结束后写入TensorBoard日志,用来监视你的度量(指标)
  • 定期将模型保存到磁盘
  • 提前终止训练
  • 训练期间查看模型内部状态和统计信息
  • 等等
      用法:
      将回调的列表传入model.fit()函数。例:
my_callbacks = [
    tf.keras.callbacks.EarlyStopping(patience=2),
    tf.keras.callbacks.ModelCheckpoint(filepath='model.{epoch:02d}-{val_loss:.2f}.h5'),
    tf.keras.callbacks.TensorBoard(log_dir='./logs'),
]
model.fit(dataset, epochs=10, callbacks=my_callbacks)

二、可用的回调

2.1 回调基类tf.keras.callbacks.Callback()

tf.keras.callbacks.Callback()

  作用:用于建立新回调的抽象基类
  属性:
  1.参数:字典。训练参数(verbosity, batch size, number of epochs…)
  2.模型:keras.models.Model的实例,训练模型的参考。

#自定义回调,便于在模型训练、评估或预测的时候使用
#回调集成自tf.keras.callbacks.Callback()基类
'''该类是keras.callbacks.Callback的子类,对方法进行了重写'''
class CustomCallback(keras.callbacks.Callback):
	#1.训练开始时
    def on_train_begin(self, logs=None):
        keys = list(logs.keys())
        print("Starting training; got log keys: {}".format(keys))
	#2.训练结束时
    def on_train_end(self, logs=None):
        keys = list(logs.keys())
        print("Stop training; got log keys: {}".format(keys))
	#3.epoch开始前
    def on_epoch_begin(self, epoch, logs=None):
        keys = list(logs.keys())
        print("Start epoch {} of training; got log keys: {}".format(epoch, keys))
	#4.epoch结束时
    def on_epoch_end(self, epoch, logs=None):
        keys = list(logs.keys())
        print("End epoch {} of training; got log keys: {}".format(epoch, keys))
	#5.测试开始前
    def on_test_begin(self, logs=None):
        keys = list(logs.keys())
        print("Start testing; got log keys: {}".format(keys))
	#6.测试结束后
    def on_test_end(self, logs=None):
        keys = list(logs.keys())
        print("Stop testing; got log keys: {}".format(keys))
	#7.预测开始前
    def on_predict_begin(self, logs=None):
        keys = list(logs.keys())
        print("Start predicting; got log keys: {}".format(keys))
	#8.预测结束后
    def on_predict_end(self, logs=None):
        keys = list(logs.keys())
        print("Stop predicting; got log keys: {}".format(keys))
	#9.训练的每个批次开始前
    def on_train_batch_begin(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Training: start of batch {}; got log keys: {}".format(batch, keys))
	#10.训练每个批次结束后
    def on_train_batch_end(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Training: end of batch {}; got log keys: {}".format(batch, keys))
	#11.测试每个批次开始前
    def on_test_batch_begin(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Evaluating: start of batch {}; got log keys: {}".format(batch, keys))
	#12.测试每个批次结束后
    def on_test_batch_end(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Evaluating: end of batch {}; got log keys: {}".format(batch, keys))
	#13.预测每个批次开始前
    def on_predict_batch_begin(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Predicting: start of batch {}; got log keys: {}".format(batch, keys))
	#14.预测每个每次结束后
    def on_predict_batch_end(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Predicting: end of batch {}; got log keys: {}".format(batch, keys))

2.2 模型检查点ModelCheckpoint类

tf.keras.callbacks.ModelCheckpoint(
    filepath,
    monitor="val_loss",
    verbose=0,
    save_best_only=False,
    save_weights_only=False,
    mode="auto",
    save_freq="epoch",
    options=None,
    **kwargs
)
参数作用
filepath保存模型文件的路径。filepath 可以包括命名格式选项,可以由 epoch 的值和 logs 的键(由 on_epoch_end 参数传递)来填充。例 weights.{epoch:02d}-{val_loss:.2f}.hdf5
monitor被监测的数据,通常为val_acc 或 val_loss 或 acc 或 loss
verbose信息展示模式,0或1。1表示输出epoch模型保存信息,默认为0表示不输出该信息。信息形如:Epoch 00001: val_acc improved from -inf to 0.49240, saving model to /xxx/checkpoint/model_001-0.3902.h5
save_best_only如果设为save_best_only=True,并且filepath不包含格式选项(如epoch)那么每个新的更好的模型都将覆盖原有模型
模式{自动,最小,最大}。如果 save_best_only=True是否覆盖已经保存的文件,取决于被监测数据的最大化或者最小化, 对于 val_acc,模式就会是 max,而对于 val_loss模式就是 min等等。 在 auto 模式中,方向会自动从被监测的数据的名字中判断。
save_weights_only如果为True,则仅保存模型的权重(model.save_weights(filepath)),否则保存整个模型的权重(model.save(filepath))。
save_freq‘epoch’或整数 。使用’epoch’,回调函数会在每个epoch结束后保存模型。使用整数,回调将在整数个批量结束时保存模型。
options如果 save_weights_only为true则tf.train.CheckpointOptions为可选对象;如果save_weights_only为false则tf.saved_model.SavedOptions 为可选对象
** kwargs其他一些参数。例period。

  作用: 回调一某种频率保存keras模型或模型权重。以一定间隔保存模型或模型权重(在检查点文件中),以后可以加载模型或模型权重并从保存的状态继续训练。
  用法: ModelCheckpoint回调与训练model.fit()结合使用。
  选项:
  1.是仅保留到目前为止已达到“最佳性能”的模型,还是在每个时期结束时保存模型(不考虑性能)。
  2.'best’的定义: 要监视的数量以及应最大化还是最小化。
  3.保存的频率:当前回调支持在每个epoch结束时或在固定数量的训练批次之后保存。
  4.是保存模型还是模型权重。
  例:

EPOCHS = 10
checkpoint_filepath = '/tmp/checkpoint'
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    save_weights_only=True,
    monitor='val_acc',
    mode='max',
    save_best_only=True)

# Model weights are saved at the end of every epoch, if it's the best seen
# so far.
model.fit(epochs=EPOCHS, callbacks=[model_checkpoint_callback])

# The model weights (that are considered the best) are loaded into the model.
model.load_weights(checkpoint_filepath)

2.3 TensorBoard 类

tf.keras.callbacks.TensorBoard(
    log_dir="logs",
    histogram_freq=0,
    write_graph=True,
    write_images=False,
    update_freq="epoch",
    profile_batch=2,
    embeddings_freq=0,
    embeddings_metadata=None,
    **kwargs
)
参数作用
log_dir日志文件保存的目录路径,保存的为模型编译时定义的指标metrics和损失loss
histogram_freq频率(在epoch中),计算模型层的激活和权重直方图。如果设置为0,则不会计算直方图。必须为直方图可视化指定验证数据(或拆分)。
write_graph是否在TensorBoard中可视化图形。当write_graph设置为True时,日志文件可能会变得很大。
write_images是否编写模型权重以在TensorBoard中可视化为图像
update_freq'batch’或’epoch’或整数。‘batch’:每批之后将损失和指标写入TensorBoard;‘epoch’:类似’batch"。整数:假设1000,回调将每1000批将指标和损失写入TensorBoard。过于频繁地向TensorBoard写入可能会减慢训练速度。
profile_batch分析批次以采样计算特征。profile_batch必须是非负整数或整数元组。一对正整数表示要分析的批次范围。默认情况下,它将配置第二批。将profile_batch = 0设置为禁用分析。
embeddings_freq可视化嵌入层的频率(历元)。如果设置为0,则嵌入将不可见。
embeddings_metadata将层名称映射到文件名的字典,在其中保存该嵌入层的元数据。查看 有关元数据文件格式的 详细信息。如果相同的元数据文件用于所有嵌入层,则可以传递字符串。
**kwargs其他一些参数,例embeddings_layer_names:要关注的层名称列表。如果为None或空列表,否则则将监测所有嵌入层;

  作用: 是tensorflow可视化工具
  用法: TensorBoard回调与训练model.fit()结合使用。
  功能:
  1.绘制指标(度量)变化图。
  2.训练图的可视化

  3.激活直方图
  4.采样分析
  例:

tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir="./logs")
model.fit(x_train, y_train, epochs=2, callbacks=[tensorboard_callback])
#对单个批次进行分析,例如 第五批。
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir='./logs',
                                                      profile_batch=5)
model.fit(x_train, y_train, epochs=2, callbacks=[tensorboard_callback])
# 分析一系列批次,例如 从10到20。
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir='./logs',                                                   profile_batch='10,20')
model.fit(x_train, y_train, epochs=2, callbacks=[tensorboard_callback])

2.4 LambdaCallback 类

tf.keras.callbacks.LambdaCallback(
    on_epoch_begin=None,
    on_epoch_end=None,
    on_batch_begin=None,
    on_batch_end=None,
    on_train_begin=None,
    on_train_end=None,
    **kwargs
)
参数作用
on_epoch_begin在每个epoch开始时调用。
on_epoch_end在每个结束时调用。
on_batch_begin在每个批处理的开头调用。
on_batch_end在每个批处理的末尾调用。
on_train_begin在模型训练开始时调用。
on_train_end在模型训练结束时调用。

  作用: 创建简单自定义的回调
  用法: LambdaCallback回调与训练model.fit()结合使用。
  回调是使用匿名函数构造的,这些匿名函数将在适当的时间被调用。回调是需要位置参数的,例如:

  • on_epoch_begin和on_epoch_end两个位置参数: epoch,logs
  • on_batch_begin和on_batch_end两个位置参数: batch,logs
  • on_train_begin并on_train_end期望一个位置参数: logs
# 在每个批次的开始处打印批次号。
batch_print_callback = LambdaCallback(
    on_batch_begin=lambda batch,logs: print(batch))

# 将epoch损失流传输到JSON格式的文件中。 文件内容不是格式正确的JSON,而是每行有一个JSON对象。
import json
json_log = open('loss_log.json', mode='wt', buffering=1)
json_logging_callback = LambdaCallback(
    on_epoch_end=lambda epoch, logs: json_log.write(
        json.dumps({'epoch': epoch, 'loss': logs['loss']}) + '\n'),
    on_train_end=lambda logs: json_log.close()
)

# 完成模型训练后,终止某些过程。
processes = ...
cleanup_callback = LambdaCallback(
    on_train_end=lambda logs: [
        p.terminate() for p in processes if p.is_alive()])

model.fit(...,
          callbacks=[batch_print_callback,
                     json_logging_callback,
                     cleanup_callback])

  
  
  
  
  
  
  

  
  
  
  
  
  
  
  
  

  
  
  

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值