目录
保存模型权重
1. 使用回调函数保存
2. 手动保存
这种是在model.fit时传入保存checkpoint的回调函数。使用的回调函数是tf.keras.callbacks.ModelCheckpoint。需要传入checkpoint保存路径,可以设置保存频率。
checkpoint_path = 'training_1/cp-{epoch:04d}.ckpt'
# Create a callback that saves the model's weights
cp_callback = tf.keras.callbacks.ModelCheckpoint(
filepath = checkpoint_path,
save_weights_only = True,
verbose = 1,
save_freq = 'epoch'
)
checkpoint_path里面的{epoch:04d}是为了使得不同epoch保存时,文件名称有差异。
save_weights_only为True时,相当于model.save_weights。为False时相当于model.save。
save_freq为‘epoch’时,在每个epoch结束时进行一次存储。当save_freq为整数N时,表示迭代 N个batch之后,进行一次保存。
使用save_weights_only为True保存的ckpt,在加载时通过model.load_weights(checkpoint_path)完成模型参数加载。
手动保存ckpt文件,使用model.save_weights(checkpoint_path)方法即可。