tensorflow笔记【7】深度学习-断点续训,保存模型
`
一、断点续训,保存模型(compile后面)
load_weights(l路径文件名)–告知文件存在哪里,直接读取已有模型的参数。
借助tensorflow给出的回调函数,直接保存参数和网络。
定义出存放模型的路径和文件名(.ckpt)
判断是不是已经有了索引表,知到是不是保存过模型参数
生成ckpt文件会同步生成索引表。
# -----------------------断点续训,保存模型--------------------
# 读取模型
checkpoint_save_path = 'checkpoint/checkpoint.ckpt'
if os.path.exists(checkpoint_save_path + '.index'):
print('------------load the model---------------')
model.load_weights(checkpoint_save_path)
# 保存模型
cp_callback = tf.keras.callbacks.ModelCheckpoint(
filepath=checkpoint_save_path,
save_weights_only=True,
save_best_only=True
)
ModelCheckpoint
keras.callbacks.ModelCheckpoint(filepath, monitor='val_loss', verbose=0, save_best_only=False, save_weights_only=False, mode='auto', period=1)
在每个训练期之后保存模型。
filepath 可以包括命名格式选项,可以由 epoch 的值和 logs 的键(由 on_epoch_end 参数传递)来填充。
例如:如果 filepath 是 weights.{
epoch:02d}-{
val_loss:.2f}.hdf5, 那么模型被保存的的文件名就会有训练轮数和验证损失。
**参数**
filepath: 字符串,保存模型的路径。
monitor: 被监测的数据。
verbose: 详细信息模式,0 或者 1 。
save_best_only: 如果 save_best_only=True, 被监测数据的最佳模型就不会被覆盖。
mode: {
auto, min, max} 的其中之一。 如果 save_best_only=True,那么是否覆盖保存文件的决定就取决于被监测数据的最大或者最小值。 对于 val_acc,模式就会是 max,而对于 val_loss,模式就需要是 min,等等。 在 auto 模式中,方向会自动从被监测的数据的名字中判断出来。
save_weights_only: 如果 True,那么只有模型的权重会被保存 (model.save_weights(filepath)), 否则的话,整个模型会被保存 (model.save(filepath))。
period: 每个检查点之间的间隔(训练轮数)。
详细代码:
# 1.导入相关模块---import
import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.layers import Flatten, Dense
import os
# 2.指定数据集----(x_train,y_train),(x_test,y_test)