TensorFlow中常见内置回调Callback

class BaseLogger:

计算每个epoch周期的平均指标,这个回调已经被自动应用在每个Keras模型,所以不需要手动设置

callbacks = tf.keras.callbacks.BaseLogger(
    stateful_metrics=None
)

model.fit(
    train_data,
    labels,
    epochs=5,
    batch_size=32,
    validation_split=0.2,
    callbacks=callbacks
)

class CSVLogger:

将每个epoch的评估及损失结果导入到一个CSV文件中

  • filename:CSV保存路径
  • separator:不同字段之间的分割符
  • append:是否在原来的文件基础之上追加
callbacks = tf.keras.callbacks.CSVLogger(
    filename='./res.log',
    separator=',',
    append=False
)

model.fit(
    train_data,
    labels,
    epochs=5,
    batch_size=32,
    validation_split=0.2,
    callbacks=callbacks
)

class EarlyStopping:

当一个被监控的指标停止提升的时候停止训练

  • monitor:需要监控的指标或者损失
  • min_delta:最小误差,只有两个epoch的评估值达到这个误差才会认为是一次变化,如果两次的误差小于min_delta则认为两次训练没有任何变化
  • patience:连续没有改进的epoch数,如果连续patience个epoch还没有改进,则停止训练
  • verbose:详细模式,用户打印控制台日志
  • mode:有三种模式,分别是minmaxauto,如果是min那么会判断如果监控的损失不在下降停止训练,如果是max,那么则发现监控的指标不在上升停止训练,如果是auto则会根据传进来的监控指标进行推断
  • baseline:监控指标的基线值,如果模型在基线上没有显示出改进,则训练将停止
  • restore_best_weights:是否从具有监控指标最佳值的epoch恢复模型权重
callbacks = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    min_delta=1e-3,
    patience=2,
    verbose=0,
    mode='min',
    baseline=None,
    restore_best_weights=False
)

model.fit(
    train_data,
    labels,
    epochs=5,
    batch_size=32,
    validation_split=0.2,
    callbacks=callbacks
)

class History:

将训练事件记录到history对象中,此回调会自动应用于每个 Keras 模型,history 对象由模型的 fit 方法返回。

模型训练后返回的history对象会包含训练时期每个epoch的精度或者损失值以及验证集的评估指标

class LearningRateScheduler:

学习率时间表

  • schedule:一个函数,它以epoch为索引(整数,从 0 开始索引)和当前学习率(浮点数)作为输入,并返回一个新的学习率作为输出(浮点数)。
  • verbose:是否打印学习更新情况
def scheduler(epoch, lr):
    if epoch < 10:
        return lr
    else:
        return lr * tf.math.exp(-0.1)


callbacks = tf.keras.callbacks.LearningRateScheduler(scheduler=scheduler,
                                                     verbose=1)

model.fit(
    train_data,
    labels,
    epochs=5,
    batch_size=32,
    validation_split=0.2,
    callbacks=callbacks
)

class ModelCheckpoint:

以某个频率保存 Keras 模型或模型权重的回调

  • filename:保存模型或者权重的路径
  • monitor:需要监测的损失或者评估指标
  • verbose:控制台输出状态
  • save_best_only:是否保存最好的模型
  • save_weights_only:是否只保存权重,否则是保存整个模型
  • mode:监控模式,minmaxauto,是按照监控的评估指标来定,如果是损失选择min,如果是准确率这种选择max,如果是auto会根据传入的monitor自动推断
  • save_freq:两种选择,分别是epochinteger,如果是epoch是每个epoch保存一次,如果是填写一个整数,代表每训练多少个批次保存一次
  • options:其它配置,用于保存模型或者参数
callbacks = tf.keras.callbacks.ModelCheckpoint(
    filename='./save_model',
    monitor='val_loss',
    verbose=1,
    save_best_only=False,
    save_weights_only=False,
    mode='auto',
    save_freq='epoch',
    options=None
)

model.fit(
    train_data,
    labels,
    epochs=5,
    batch_size=32,
    validation_split=0.2,
    callbacks=callbacks
)

class ProgbarLogger:

打印精度到标准输出

  • 3
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

海洋 之心

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值