根据是否需要多平台部署,可以选择仅保存模型参数还是同时保存参数和结构。
文章目录
仅保存模型参数
当只关注训练过程不考虑多平台部署的情况下,可以只保存模型参数,而恢复模型时先用模型源码恢复模型结构,再载入参数即可。
tf 模型
tf.train.Checkpoint
只保存模型的参数,不保存模型的计算过程,适合用于此种情况。tf.train.Checkpoint
可以保存与恢复 tf 中的大部分对象(所有包含 Checkpointable State
的对象),包括tf.keras.optimizer
、tf.Variable
、tf.keras.Layer
和 tf.keras.Model
实例都可以被保存和恢复。
# .ckpt
# save
Split = split.Split() # 模型源码恢复模型结构
checkpoint = tf.train.Checkpoint(model=Split) # 初始化,注意键值对
checkpoint.save(config.saved_models + 'split' + '.ckpt')
# restore
Split = split.Split() # 模型源码恢复模型结构
checkpoint = tf.train.Checkpoint(model=Split) # 使用一致的键值对再次实例化
checkpoint.restore(tf.train.latest_checkpoint(config.saved_models))
keras 模型
对于保存和恢复 keras 模型参数,可以选择保存为 h5 格式,或者是 tf 格式。
当使用 tf 格式时,与 checkpoint 保存与恢复方法一致。但即便如此,保存与恢复时却不能随意组合乱用,save_weights
只能和 load_weights
搭配保存和恢复,checkpoint.save
只能和 checkpoint.restore
搭配使用。官方更推荐直接用 checkpoint(Prefer tf.train.Checkpoint
over save_weights
for training checkpoints.)。
# .ckpt 不指定 h5 时
# save
Split = split.Split()
Split.save_weights(config.saved_models + 'split_' + str(epoch))
# load
Split = split.Split()
Split.load_weights(config.saved_models + 'split_1')
同时保存模型结构和参数
如果需要导出模型结构和参数,即不需要模型源码也能运行模型,可以使用 tf.saved_model.save
和 tf.saved_model.load
保存和恢复模型和参数。
这种一般存在于多平台部署的情况,比如已经训练好了一个模型,需要部署到服务器、移动端和嵌入式等,第一步往往就是将模型完成导出(序列化)为一系列标准格式的文件。
tf 模型
SavedModel 就是 tf 提供的统一模型导出格式,也是 tf2 中主要使用的。如此,可将这一格式作为中介,将训练好的模型部署在多种平台上。SavedModel 比 上面 checkpoint 更进一步,包含模型的完整信息(参数权值 + 计算图)。
keras 模型
非 keras Sequential / Functional
Keras 的模型都可以导出为 SavedModel 格式,但是通过继承 tf.keras.Model
类建立的 keras 模型,得使用 @tf.function
修饰需要导出为 SavedModel 格式的方法(加载模型后,预测/推断时不能直接用 model()
,而是使用 model.call()
)。
# .pb 格式
# save
Split = split.Split()
tf.saved_model.save(Split, config.saved_models)
# load
Split = tf.saved_model.load(config.saved_models)
Sequential / Functional
如果使用 Sequential 构建模型,可以使用如下方式保存 h5 格式的结构和参数,当不指定 .h5 后缀会保存为 .pb。
# save
model.save('model.h5')
'''
使用 .save 保存的信息包括:
- The model architecture, allowing to re-instantiate the model.
- The model weights.
- The state of the optimizer, allowing to resume training exactly where you left off.
'''
# load
model = tf.keras.models.load_model('model.h5')