一、首先介绍 TensorFlow 模型持久化的工作原理及其代码实现。
1.1 使用 tf.train.Saver 类。以下代码给出了保存 TensorFlow 计算图的方法。
import os

import tensorflow as tf

# Checkpoint prefix: Saver.save() writes the variable values under this prefix
# and, by default, the graph definition to MODEL_PATH + ".meta".
MODEL_PATH = "D:/gj20170720/model.ckpt"

v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1")
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2")
# `result` is not evaluated here, but the add op it creates becomes part of
# the saved graph, so it can be fetched again after re-importing the .meta file.
result = v1 + v2

# tf.initialize_all_variables() is deprecated; global_variables_initializer()
# is its documented replacement.
init_op = tf.global_variables_initializer()
saver = tf.train.Saver()

# Saver.save() fails if the checkpoint's parent directory does not exist.
os.makedirs(os.path.dirname(MODEL_PATH), exist_ok=True)

with tf.Session() as sess:
    sess.run(init_op)
    saver.save(sess, MODEL_PATH)
1.2 加载已保存模型的方法
import tensorflow as tf

__author__ = 'casgj'

# Checkpoint prefix written by the save example.
MODEL_PATH = "D:/gj20170720/model.ckpt"

# Rebuild the same graph by hand. The Saver matches checkpoint entries to
# variables by name, so the names "v1"/"v2" must be identical to the saved
# ones; the initial values given here are replaced by restore().
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1")
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2")
result = v1 + v2

saver = tf.train.Saver()

with tf.Session() as sess:
    # No initializer is run: restore() assigns the saved values instead.
    saver.restore(sess, MODEL_PATH)
    print(sess.run(result))
如果不希望重复定义图上的运算，也可以直接从 .meta 文件加载已保存的计算图：
import tensorflow as tf

__author__ = 'casgj'

# Checkpoint prefix written by the save example.
MODEL_PATH = "D:/gj20170720/model.ckpt"

# Recreate the saved graph from its .meta file instead of redefining the
# variables and their sum in Python; import_meta_graph() returns a Saver
# bound to the imported graph.
saver = tf.train.import_meta_graph(MODEL_PATH + ".meta")

with tf.Session() as sess:
    saver.restore(sess, MODEL_PATH)
    # The sum was built as `v1 + v2` without an explicit name, so its op got
    # TensorFlow's default name "add"; ":0" selects that op's first output.
    # NOTE(review): assumes the saved graph contains no other unnamed add op.
    result = tf.get_default_graph().get_tensor_by_name("add:0")
    print(sess.run(result))
1.3变量重命名的使用
import tensorflow as tf

__author__ = 'casgj'

# These variables are named "other-v1"/"other-v2", which differ from the
# names "v1"/"v2" under which values were saved. Giving Saver a dict maps
# each saved name (key) to the variable that should hold its value.
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="other-v1")
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="other-v2")
name_to_var = {"v1": v1, "v2": v2}
saver = tf.train.Saver(name_to_var)
1.4 滑动平均模型的保存
__author__ = 'casgj'
import tensorflow as tf

# The single variable whose moving average will be tracked.
v = tf.Variable(0, dtype=tf.float32, name="v")

# List global variables before the EMA is set up (only `v` so far).
# tf.all_variables() is deprecated; tf.global_variables() replaces it.
for var in tf.global_variables():
    print(var.name)

ema = tf.train.ExponentialMovingAverage(0.99)
# apply() creates a shadow variable for each variable passed in (per the
# ExponentialMovingAverage docs) and returns the op that updates them.
maintain_averages_op = ema.apply(tf.global_variables())

# List again: the shadow variable(s) created by apply() now appear as well.
for var in tf.global_variables():
    print(var.name)

saver = tf.train.Saver()
with tf.Session() as sess:
init_op=tf.initialize_all_variables()
sess.run(init_op)
sess.run(tf.assign(v,10))
sess.run(maintain_averages_op)
saver.save(sess,