TensorFlow 模型保存与恢复
欢迎关注简书:公输睚信
上一篇文章 TensorFlow 训练 CNN 分类器 中说明了训练简单 CNN 模型的整个过程,并在训练结束后使用 .save
函数来保存训练的结果,其后通过使用 tf.train.import_meta_graph
和 .restore
函数来导入模型进行推断。本文承接上文,对模型保存与恢复做一个总结。
总的来说,模型在保存和恢复时最重要的是留下数据接口,方便使用时传入数据和获取结果。TensorFlow 中常用的模型保存格式为 .ckpt 和 .pb,下面分别进行详细说明。
一、ckpt 格式模型保存与恢复
.ckpt 格式保存与恢复都很简单,具体可参考 TensorFlow 训练 CNN 分类器。
1. ckpt 格式模型保存
inputs = tf.placeholder(tf.float32, shape=[None, ···], name='inputs') <-- 入口
···
prediction = tf.nn.softmax(logits, name='prediction') <-- 出口(具体输出依情况而定,下同)
···
saver = tf.train.Saver()
···
with tf.Session() as sess:
··· <-- 训练过程
saver.save(sess, './xxx/xxx.ckpt') <-- 模型保存
如上述代码所示,假设你定义了一个 TensorFlow 模型,数据入口由占位符 inputs
给