tensorflow 实现模型持久化
本文主要介绍如何用tensorflow来实现训练好的模型的持久化以及模型的引用。
import tensorflow as tf
saver=tf.train.Saver() #用来创建一个持久化类
在训练的时候,可以设置迭代固定的次数然后保存模型
save.save(sess,'mnist_fenlei_model/',global_step=global_step)
‘mnist_fenlei_model/’,是你保存模型的位置,这里 global_step 一定要写,不然后面引用模型的时候会有错误。
在测试程序中,用以下代码实现模型引用:
ckpt=tf.train.get_checkpoint_state('mnist_fenlei_model') #自动找到最新的模型。
saver=tf.train.Saver()
if ckpt and ckpt.model_checkpoint_path:
saver.restore(sess, ckpt.model_checkpoint_path)
else:
print('no ckpt ')
这样就可以用训练好的模型实现测试或者其他应用。