回调函数(callbacks)(TensorBoard、EarlyStopping、ModelCheckPoint)
本文主要介绍tf.Keras.callbacks中的三种回调函数:TensorBoard、EarlyStopping、ModelCheckPoint。
-
TensorBoard:是Tensorflow自带的一个强大的可视化工具,也是一个web应用程序套件,它通过将tensorflow程序输出的日志文件的信息可视化使得tensorflow程序的理解、调试和优化更加简单高效。Tensorboard的可视化依赖于tensorflow程序运行输出的日志文件,因而tensorboard和tensorflow程序在不同的进程中运行。
参数详解:https://tensorflow.google.cn/api_docs/python/tf/keras/callbacks/TensorBoard -
EarlyStopping:为了获得性能良好的神经网络,网络定型过程中需要进行许多关于所用设置(超参数)的决策。超参数之一是定型周期(epoch)的数量:亦即应当完整遍历数据集多少次(一次为一个epoch)?如果epoch数量太少,网络有可能发生欠拟合(即对于定型数据的学习不够充分);如果epoch数量太多,则有可能发生过拟合(即网络对定型数据中的“噪声”而非信号拟合)。早停法旨在解决epoch数量需要手动设置的问题。它也可以被视为一种能够避免网络发生过拟合的正则化方法(与L1/L2权重衰减和丢弃法类似)。根本原因就是因为继续训练会导致测试集上的准确率下降。那继续训练导致测试准确率下降的原因猜测可能是1. 过拟合 2. 学习率过大导致不收敛。
参数详解:https://tensorflow.google.cn/api_docs/python/tf/keras/callbacks/EarlyStopping
monitor: 需要监视的量,val_loss,val_acc。
patience: 当early stop被激活(如发现loss相比上一个epoch训练没有下降),则经过patience个epoch后停止训练。
mode: ‘auto’,‘min’,'max’之一,在min模式训练,如果检测值停止下降则终止训练。在max模式下,当检测值不再上升的时候则停止训练。
min_delta:阈值。 -
ModelCheckPoint:在每次迭代之后保存模型。
参数详解:https://tensorflow.google.cn/api_docs/python/tf/keras/callbacks/ModelCheckpoint
Callback是在训练过程中调用,因此我们要在模型训练(model.fit)中添加
假设你已经选择好模型,并构建好网络以及编译后,到了训练模型这一步骤,使用回调函数。如果你不懂之前的步骤请参考:https://blog.csdn.net/qq_40913465/article/details/104249124
代码示例:
#使用回调函数
#tensorBoard、EarlyStopping、ModelCheckPoint
logdir = os.path.join("callbacks")
if not os.path.exists(logdir):
os.mkdir(logdir)
output_model_file = os.path.join(logdir,"fashion_mnist_model.h5")
callbacks = [
keras.callbacks.TensorBoard(log_dir=logdir), #log_dir将输出的日志保存在所要保存的路径中
keras.callbacks.ModelCheckpoint(output_model_file, save_best_only = True),
keras.callbacks.EarlyStopping(patience=5, min_delta=1e-3),
]
#训练模型会,返回一个结果保存在history中
history = model.fit(x_train_scaled, y_train, epochs=50,
validation_data=(x_valid_scaled, y_valid),
callbacks=callbacks) #使用回调函数
结果展示:
这是运行后生成的文件,callbacks为tensorboard生成的文件夹,fashion_mnist_model为modelcheckpoint生成的文件,这样我们就可以查看相应的文件了,如果你想在win下可视化数据,请参考win下怎么查看tensorboard