不同情景下 tf/keras 模型的保存与恢复

根据是否需要多平台部署,可以选择仅保存模型参数还是同时保存参数和结构。

仅保存模型参数

当只关注训练过程不考虑多平台部署的情况下,可以只保存模型参数,而恢复模型时先用模型源码恢复模型结构,再载入参数即可。

tf 模型

tf.train.Checkpoint 只保存模型的参数,不保存模型的计算过程,适合用于此种情况。tf.train.Checkpoint 可以保存与恢复 tf 中的大部分对象(所有包含 Checkpointable State 的对象),包括tf.keras.optimizertf.Variabletf.keras.Layertf.keras.Model 实例都可以被保存和恢复。

# .ckpt
# save
Split = split.Split()	# 模型源码恢复模型结构
checkpoint = tf.train.Checkpoint(model=Split)	# 初始化,注意键值对
checkpoint.save(config.saved_models + 'split' + '.ckpt')

# restore
Split = split.Split()	# 模型源码恢复模型结构
checkpoint = tf.train.Checkpoint(model=Split)	# 使用一致的键值对再次实例化
checkpoint.restore(tf.train.latest_checkpoint(config.saved_models))

keras 模型

对于保存和恢复 keras 模型参数,可以选择保存为 h5 格式,或者是 tf 格式。

当使用 tf 格式时,与 checkpoint 保存与恢复方法一致。但即便如此,保存与恢复时却不能随意组合乱用,save_weights 只能和 load_weights 搭配保存和恢复,checkpoint.save 只能和 checkpoint.restore 搭配使用。官方更推荐直接用 checkpoint(Prefer tf.train.Checkpoint over save_weights for training checkpoints.)。

# .ckpt 不指定 h5 时
# save
Split = split.Split()
Split.save_weights(config.saved_models + 'split_' + str(epoch))

# load
Split = split.Split()
Split.load_weights(config.saved_models + 'split_1')

同时保存模型结构和参数

如果需要导出模型结构和参数,即不需要模型源码也能运行模型,可以使用 tf.saved_model.savetf.saved_model.load 保存和恢复模型和参数。

这种一般存在于多平台部署的情况,比如已经训练好了一个模型,需要部署到服务器、移动端和嵌入式等,第一步往往就是将模型完成导出(序列化)为一系列标准格式的文件。

tf 模型

SavedModel 就是 tf 提供的统一模型导出格式,也是 tf2 中主要使用的。如此,可将这一格式作为中介,将训练好的模型部署在多种平台上。SavedModel 比 上面 checkpoint 更进一步,包含模型的完整信息(参数权值 + 计算图)。

keras 模型

非 keras Sequential / Functional

Keras 的模型都可以导出为 SavedModel 格式,但是通过继承 tf.keras.Model 类建立的 keras 模型,得使用 @tf.function 修饰需要导出为 SavedModel 格式的方法(加载模型后,预测/推断时不能直接用 model(),而是使用 model.call())。

# .pb 格式
# save
Split = split.Split()
tf.saved_model.save(Split, config.saved_models)

# load
Split = tf.saved_model.load(config.saved_models)
Sequential / Functional

如果使用 Sequential 构建模型,可以使用如下方式保存 h5 格式的结构和参数,当不指定 .h5 后缀会保存为 .pb。

# save
model.save('model.h5')		
'''
使用 .save 保存的信息包括:
- The model architecture, allowing to re-instantiate the model.
- The model weights.
- The state of the optimizer, allowing to resume training exactly where you left off.
'''
# load
model = tf.keras.models.load_model('model.h5')
  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值