import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.callbacks import EarlyStopping

tensorboard = tf.keras.callbacks.TensorBoard(log_dir=self.log_dir, histogram_freq=1)
callbacks = [
    # This callback saves a SavedModel every epoch.
    # We include the current epoch in the folder name.
    keras.callbacks.ModelCheckpoint(
        filepath=self.checkpoint_dir + "/ckpt-{epoch}",
        save_freq="epoch",  # save frequency: save once per epoch
    ),
    # patience: number of epochs with no improvement after which training will be stopped.
    # min_delta: minimum change in the monitored quantity to qualify as an improvement,
    # i.e. an absolute change of less than min_delta counts as no improvement.
    EarlyStopping(monitor='loss', patience=10, min_delta=0.0003),
    tensorboard,
]
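A small sketch of how the "ckpt-{epoch}" placeholder in the filepath turns into folder names. Keras passes a 0-based epoch index to callbacks and fills the template with epoch + 1, so the first folder is ckpt-1. The helper name expand_checkpoint_path is made up for illustration and is not part of Keras; the zero-padded template is an optional variation, not something the code above uses.

```python
def expand_checkpoint_path(template, epoch_index):
    """Fill the {epoch} placeholder the way ModelCheckpoint does (epoch index + 1)."""
    return template.format(epoch=epoch_index + 1)


print(expand_checkpoint_path("./ckpt/ckpt-{epoch}", 0))       # ./ckpt/ckpt-1
print(expand_checkpoint_path("./ckpt/ckpt-{epoch}", 41))      # ./ckpt/ckpt-42
print(expand_checkpoint_path("./ckpt/ckpt-{epoch:03d}", 41))  # ./ckpt/ckpt-042
```

Zero-padding keeps the folders in the right order when sorted by name.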
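1. ModelCheckpoint: back up the model at every epoch
The result of every epoch is saved. If the program is interrupted, the model can be loaded from the most recent epoch next time and training can continue.
checkpoint and pb model formats: https://zhuanlan.zhihu.com/p/32887066
The pb format decouples building the model from using the model.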
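Loading:
from tensorflow.keras import models

fin_model = './ckpt/ckpt-42'
# custom_objects maps the names of custom metrics/layers to their implementations
model = models.load_model(
    fin_model,
    custom_objects=MyEvaluate.metric_json)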
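To resume from the newest checkpoint instead of a hard-coded path like './ckpt/ckpt-42', the folder names can be scanned for the highest epoch number. This is a minimal sketch, assuming the folders follow the "ckpt-{epoch}" naming used above; latest_checkpoint is a hypothetical helper, not a Keras function. The epoch is compared as an integer, because a plain string sort would put ckpt-9 after ckpt-10.

```python
import os
import re

# Matches folder names produced by the "ckpt-{epoch}" template.
CKPT_PATTERN = re.compile(r"^ckpt-(\d+)$")


def latest_checkpoint(checkpoint_dir):
    """Return the path of the ckpt-<epoch> entry with the highest epoch, or None."""
    best_epoch, best_path = -1, None
    for name in os.listdir(checkpoint_dir):
        match = CKPT_PATTERN.match(name)
        if match and int(match.group(1)) > best_epoch:
            best_epoch = int(match.group(1))
            best_path = os.path.join(checkpoint_dir, name)
    return best_path
```

The returned path can be passed to models.load_model in place of fin_model. When continuing training, model.fit accepts an initial_epoch argument so that the epoch counter picks up where it left off.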
2. EarlyStopping: stop training early when the model is no longer improving noticeably
EarlyStopping(monitor='loss', patience=10, min_delta=0.0003)
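To see what patience and min_delta mean in practice, here is a simplified pure-Python simulation of the rule for a quantity that should go down, such as loss. It is a sketch for understanding only, not Keras's actual implementation; early_stop_epoch is a hypothetical helper and the loss values are invented.

```python
def early_stop_epoch(losses, patience=10, min_delta=0.0003):
    """Return the 0-based epoch index at which training would stop, or None."""
    best = float("inf")
    wait = 0  # consecutive epochs without a sufficient improvement
    for epoch, loss in enumerate(losses):
        if best - loss > min_delta:
            # Improved by more than min_delta: remember it and reset the counter.
            best = loss
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None


# The loss drops quickly, then changes by only 0.0001 per epoch.
losses = [1.0, 0.9, 0.8999, 0.8998]
print(early_stop_epoch(losses, patience=2))  # 3
```

Note that small gains are compared against the best value seen so far, not against the previous epoch, so a long run of tiny steps still counts as no improvement until the total drop exceeds min_delta.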
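3. TensorBoard: monitor the model during training
Official tutorial: https://www.tensorflow.org/tensorboard/get_started
Zhihu summary: https://zhuanlan.zhihu.com/p/59986254
The TensorFlow program saves events and related data as logs. Once TensorBoard is running, it loads the log files automatically and renders them as charts on a web page, which makes it easy to inspect the training process and the results.
The monitored data is split into training and validation, each written to its own subfolder under log_dir.
Once the logs are visualized, a graphical view of the model can also be viewed.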
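A quick way to check what has been written under log_dir before opening TensorBoard: walk the directory and group the event files by subfolder. This is a sketch, assuming the usual "events.out.tfevents." file name prefix; event_files_by_run is a hypothetical helper, not part of TensorBoard.

```python
import os


def event_files_by_run(log_dir):
    """Map each subfolder of log_dir (relative path) to the event files it contains."""
    runs = {}
    for root, _dirs, files in os.walk(log_dir):
        events = sorted(f for f in files if f.startswith("events.out.tfevents."))
        if events:
            runs[os.path.relpath(root, log_dir)] = events
    return runs
```

For a run that used validation data, the keys would typically be "train" and "validation". The web page itself is started with the command tensorboard --logdir followed by the log directory.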