tensorflow2.0常用回调函数小结

经查看官方文档将常用回调函数做以下小结,目的是了解每个回调函数的作用与参数用法。

上图是tf2.0的全部回调函数,在这里介绍常用的4个回调函数:EarlyStopping,tensorboard,ModelCheckpoint,history。

1、tf.keras.callbacks.EarlyStopping

目的/作用:当监控的值停止变化时,提前结束训练。

定义:

tf.keras.callbacks.EarlyStopping(
    monitor='val_loss', min_delta=0, patience=0, verbose=0, mode='auto',
    baseline=None, restore_best_weights=False
)

由上面的代码段可以得知,当未自己手动设置monitor时,默认监控的是验证集的loss(val_loss)。

常用参数介绍:

monitor:监控的值。
min_delta:监视值的最小变化,即,绝对变化小于min_delta的情况,将视为没有变化
patience:在多少个epoch,监控的值没有变化后,将停止训练。(也就是连续多少个epoch,监控值的绝对变化小于min_delta)

示例:

callback = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=3)
# This callback will stop the training when there is no improvement in
# the validation loss for three consecutive epochs.
model.fit(data, labels, epochs=100, callbacks=[callback],
    validation_data=(val_data, val_labels))

2、tf.keras.callbacks.TensorBoard

作用:tensorflow的可视化工具

定义:

tf.keras.callbacks.TensorBoard(
    log_dir='logs', histogram_freq=0, write_graph=True, write_images=False,
    update_freq='epoch', profile_batch=2, embeddings_freq=0,
    embeddings_metadata=None, **kwargs
)

常用参数:

log_dir:将TensorBoard解析的日志文件保存到的目录路径。

其余用到再补充

示例:

logdir = os.path.join("callbacks")
if not os.path.exists(logdir):
    os.mkdir(logdir)

callbacks = [
    keras.callbacks.TensorBoard(logdir),]

history = model.fit(x_train_scaled, y_train, epochs=100,
                    validation_data=(x_valid_scaled, y_valid),
                    callbacks = callbacks)

tensorboard显示:

 

3、tf.keras.callbacks.ModelCheckpoint

作用:在每一次epoch后保存模型

定义:

tf.keras.callbacks.ModelCheckpoint(
    filepath, monitor='val_loss', verbose=0, save_best_only=False,
    save_weights_only=False, mode='auto', save_freq='epoch', **kwargs
)

常用参数:

filepath:字符串,保存模型文件的路径。

示例:

logdir = os.path.join("callbacks")
output_model_file = os.path.join(logdir,
                                 "fashion_mnist_model.h5")

callbacks = [
    keras.callbacks.ModelCheckpoint(output_model_file,
                                    save_best_only = True),#保存最好的模型,默认保存最近的
]

history = model.fit(x_train_scaled, y_train, epochs=100,
                    validation_data=(x_valid_scaled, y_valid),
                    callbacks = callbacks)

4、tf.keras.callbacks.History

这个回调函数会自动应用到每一个keras模型,History对象通过模型的fit方法得到返回。

history = model.fit(x_train_scaled, y_train, epochs=100,
                    validation_data=(x_valid_scaled, y_valid),
                    callbacks = callbacks)

 

  • 2
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值