10_ tf.keras Callbacks概述

1. tf.keras Callbacks是什么

Keras官方文档:回调Callbacks

Callbacks的本质是一组函数对象,代码层面就是一个Python List在训练过程中的特定时期被执行,这些函数对象可以在训练过程中访问,保存或者修改训练中的参数,相当于在训练之前写好了几个锦囊,这些锦囊会在特定的时间被打开并且执行。用好Callbacks训练过程将会是一个很愉快的过程。

典型应用场景:

  • 解决训练之后的失控问题。
  • 不知道训练多少轮可以得到想要的结果,这个时候可以通过Callbacks设置当模型不能进一步优化时停止训练。
  • 通过Tensorboard等工具查看训练模型的内在状态和统计,全面直观的监控训练过程。

2. 使用Callbacks

  • 1.实例化Callback
  • 2.以Python List形式传给model.fit()方法的callbacks参数。

例如实现Plateau学习率策略:

reduce_lr = ReduceLROnPlateau(monitor='val_loss',factor=0.2,
								patience=5,min_lr=0.001)
model.fit(x,y,callbacks=[reduce_lr])

3. tf.keras内置Callback函数

动态模型保存 ModelCheckpoint

tf.keras.callbacks.ModelCheckpoint(filepath, monitor='val_loss', 
								verbose=0, save_best_only=False, 
								save_weights_only=False, 
								mode='auto', period=1)

主要参数:

  • filepath:模型保存的路径.
  • monitorsave_best_only:监控monitor指定的指标,设置save_best_onlyTrue时可以保存最好的模型,防止模型参数占用太多的硬盘容量.
  • save_weights_only:为True时只保存权重,等于model.save_weights(),为False是保存权重和网络,等于model.save().

动态训练终止 EarlyStopping

tf.keras.callbacks.EarlyStopping(monitor='val_loss', min_delta=0, 
								patience=0, verbose=0, 
								mode='auto', baseline=None, 
								restore_best_weights=False)

主要参数:

  • min_delta:在被监测的数据中被认为是提升的最小变化, 例如,小于 min_delta 的绝对变化会被认为没有提升。
  • patience: 在监测质量经过多少轮次没有进度时即停止。如果验证频率 (model.fit(validation_freq=5)) 大于 1,则可能不会在每个轮次都产生验证数.
  • mode: {auto, min, max} 其中之一,在 min 模式中,当被监测的数据停止下降训练就会停止;在 max 模式中当被监测的数据停止上升,训练就会停止;在 auto 模式中方向会自动从被监测的数据的名字中判断出来.
  • baseline: 要监控的数量的基准值。 如果模型没有显示基准的改善,训练将停止.
  • restore_best_weights: 是否从具有监测数量的最佳值的时期恢复模型权重。 如果为 False,则使用在训练的最后一步获得的模型权重.

远程事件监控 RemoteMonitor

tf.keras.callbacks.RemoteMonitor(root='http://localhost:9000',
								 path='/publish/epoch/end/',
								 field='data', headers=None, 
								 send_as_json=False)

主要参数

  • root: 目标服务器的根地址.
  • path: 相对于 root 的路径,事件数据被送达的地址.
  • field: JSON ,数据被保存的领域.
  • headers: 可选自定义的 HTTP 的头字段.
  • send_as_json: 请求是否应该以 application/json 格式发送.

自定义动态学习率 ReduceLROnPlateau

tf.keras.callbacks.LearningRateScheduler(schedule, verbose=0)

主要参数

  • schedule: 一个函数,接受轮索引数作为输入(整数,从 0 开始迭代) 然后返回一个学习速率作为输出(浮点数).

数据可视化 Tensorboard

tf.keras.callbacks.TensorBoard(log_dir='./logs', histogram_freq=0, 
									batch_size=32, write_graph=True,
									write_grads=False, write_images=False,
									embeddings_freq=0, 
									embeddings_layer_names=None,
								    embeddings_metadata=None, embeddings_data=None, 
								    update_freq='epoch')

简单自定义Callback LambdaCallback

tf.keras.callbacks.LambdaCallback(on_epoch_begin=None, on_epoch_end=None, 
										on_batch_begin=None, on_batch_end=None,
									    on_train_begin=None, on_train_end=None)
  • 2
    点赞
  • 16
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值