TensorFlow2.x——回调函数(callbacks)(TensorBoard、EarlyStopping、ModelCheckPoint)

回调函数(callbacks)(TensorBoard、EarlyStopping、ModelCheckPoint)

本文主要介绍tf.Keras.callbacks中的三种回调函数:TensorBoard、EarlyStopping、ModelCheckPoint。

  • TensorBoard:是Tensorflow自带的一个强大的可视化工具,也是一个web应用程序套件,它通过将tensorflow程序输出的日志文件的信息可视化使得tensorflow程序的理解、调试和优化更加简单高效。Tensorboard的可视化依赖于tensorflow程序运行输出的日志文件,因而tensorboard和tensorflow程序在不同的进程中运行。
    参数详解:https://tensorflow.google.cn/api_docs/python/tf/keras/callbacks/TensorBoard

  • EarlyStopping:为了获得性能良好的神经网络,网络定型过程中需要进行许多关于所用设置(超参数)的决策。超参数之一是定型周期(epoch)的数量:亦即应当完整遍历数据集多少次(一次为一个epoch)?如果epoch数量太少,网络有可能发生欠拟合(即对于定型数据的学习不够充分);如果epoch数量太多,则有可能发生过拟合(即网络对定型数据中的“噪声”而非信号拟合)。早停法旨在解决epoch数量需要手动设置的问题。它也可以被视为一种能够避免网络发生过拟合的正则化方法(与L1/L2权重衰减和丢弃法类似)。根本原因就是因为继续训练会导致测试集上的准确率下降。那继续训练导致测试准确率下降的原因猜测可能是1. 过拟合 2. 学习率过大导致不收敛。
    参数详解:https://tensorflow.google.cn/api_docs/python/tf/keras/callbacks/EarlyStopping
    monitor: 需要监视的量,val_loss,val_acc。
    patience: 当early stop被激活(如发现loss相比上一个epoch训练没有下降),则经过patience个epoch后停止训练。
    mode: ‘auto’,‘min’,'max’之一,在min模式训练,如果检测值停止下降则终止训练。在max模式下,当检测值不再上升的时候则停止训练。
    min_delta:阈值。

  • ModelCheckPoint:在每次迭代之后保存模型。
    参数详解:https://tensorflow.google.cn/api_docs/python/tf/keras/callbacks/ModelCheckpoint

Callback是在训练过程中调用,因此我们要在模型训练(model.fit)中添加

假设你已经选择好模型,并构建好网络以及编译后,到了训练模型这一步骤,使用回调函数。如果你不懂之前的步骤请参考:https://blog.csdn.net/qq_40913465/article/details/104249124

代码示例:

#使用回调函数
#tensorBoard、EarlyStopping、ModelCheckPoint
logdir = os.path.join("callbacks")
if not os.path.exists(logdir):
    os.mkdir(logdir)
output_model_file = os.path.join(logdir,"fashion_mnist_model.h5")

callbacks = [
    keras.callbacks.TensorBoard(log_dir=logdir), #log_dir将输出的日志保存在所要保存的路径中
    keras.callbacks.ModelCheckpoint(output_model_file, save_best_only = True), 
    keras.callbacks.EarlyStopping(patience=5, min_delta=1e-3),
]
#训练模型会,返回一个结果保存在history中
history = model.fit(x_train_scaled, y_train, epochs=50, 
                    validation_data=(x_valid_scaled, y_valid), 
                    callbacks=callbacks) #使用回调函数

结果展示:
在这里插入图片描述
这是运行后生成的文件,callbacks为tensorboard生成的文件夹,fashion_mnist_model为modelcheckpoint生成的文件,这样我们就可以查看相应的文件了,如果你想在win下可视化数据,请参考win下怎么查看tensorboard

这行代码导入了 TensorFlow 的 Keras 库中的三个回调函数:`ReduceLROnPlateau`、`ModelCheckpoint` 和 `EarlyStopping`。这三个回调函数都可以在训练神经网络时起到重要的作用。 `ReduceLROnPlateau` 回调函数用于在训练过程中动态地调整学习率,以便更好地训练模型。该回调函数可以设置监控的指标、调整学习率的因子、调整学习率的频率等参数。例如: ```python reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=5, min_lr=0.0001) ``` 其中,`monitor` 是监控的指标,例如 validation loss,`factor` 是调整学习率的因子,即将学习率乘以该因子,`patience` 是连续多少个 epoch 指标没有提升时进行调整,`min_lr` 是最小学习率,即学习率不会低于该值。 `ModelCheckpoint` 回调函数用于定期保存训练过程中的模型权重,以便在训练过程中出现中断或意外情况时,可以继续训练或者恢复最佳模型。该回调函数可以设置保存模型的路径、保存的文件名、保存的频率、是否只保存最佳模型等参数。例如: ```python checkpoint = ModelCheckpoint('model.h5', monitor='val_loss', verbose=1, save_best_only=True, mode='min') ``` 其中,`model.h5` 是保存模型的路径和文件名,`monitor` 是监控的指标,例如 validation loss,`verbose` 是输出保存模型的信息,`save_best_only` 表示只保存最佳模型,`mode` 表示监控指标的模式,例如最小化指标。 `EarlyStopping` 回调函数用于在训练过程中检测验证集的性能是否有提升,如果连续若干个 epoch 验证集的指标没有提升,则停止训练。该回调函数可以设置检测的监控指标、检测的循环周期、最大等待轮数等参数。例如: ```python earlystop = EarlyStopping(monitor='val_loss', min_delta=0, patience=10, verbose=1, mode='min') ``` 其中,`monitor` 是监控的指标,例如 validation loss,`min_delta` 是最小变化量,即当指标变化小于该值时认为没有提升,`patience` 是最大等待轮数,即当连续多少个 epoch 没有提升时停止训练,`verbose` 是输出停止训练的信息,`mode` 表示监控指标的模式,例如最小化指标。 在训练过程中,可以将这三个回调函数传递给 `fit` 函数,例如: ```python model.fit(X_train, y_train, validation_data=(X_val, y_val), epochs=100, batch_size=32, callbacks=[reduce_lr, checkpoint, earlystop]) ``` 这样就可以在训练过程中动态调整学习率、保存模型和早期停止训练。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值