(转)原创文章详见:http://blog.csdn.net/sinat_29957455/article/details/78483631
一、模型文件的保存
在训练一个TensorFlow模型之后,我们可以将训练好的模型保存成文件,这样可以方便下一次对新的数据进行预测的时候直接加载训练好的模型即可获得结果,下面通过TensorFlow提供的tf.train.Saver函数,将一个模型保存成文件,一般习惯性的将TensorFlow的模型文件命名为*.ckpt文件。
- <span style="font-size:14px;">import tensorflow as tf
- if __name__ == "__main__":
- #定义两个变量
- a = tf.Variable(tf.constant(1.0,shape=[1],name="a"))
- b = tf.Variable(tf.constant(2.0,shape=[1],name="b"))
- c = a + b
- init = tf.initialize_all_variables()
- sess = tf.Session()
- sess.run(init)
- #声明一个保存
- saver = tf.train.Saver()
- saver.save(sess,"./model.ckpt")</span>
二、模型文件的读取
TensorFlow对于模型文件的读取方式也提供了几种方法,根据读取不同的文件来获取不同的信息。
1、加载model.ckpt文件来初始化变量
- <span style="font-size:14px;"> a = tf.Variable(tf.constant(3.0,shape=[1],name="a"))
- b = tf.Variable(tf.constant(4.0,shape=[1],name="b"))
- c = a + b
- saver = tf.train.Saver()
- sess = tf.Session()
- saver.restore(sess,"model.ckpt")
- print(sess.run(c))
- #[ 3.]</span>
2、加载持久化图获取全部变量
- <span style="font-size:14px;"> saver = tf.train.import_meta_graph("model.ckpt.meta")
- sess = tf.Session()
- saver.restore(sess,"model.ckpt")
- print(sess.run(tf.get_default_graph().get_tensor_by_name("a:0")))
- #[ 1.]
- print(sess.run(tf.get_default_graph().get_tensor_by_name("b:0")))
- #[ 2.]
- print(sess.run(tf.get_default_graph().get_tensor_by_name("add:0")))
- #[ 3.]</span>
3、加载指定列表变量
- <span style="font-size:14px;"> a = tf.Variable(tf.constant(3.0,shape=[1],name="a"))
- b = tf.Variable(tf.constant(4.0,shape=[1],name="b"))
- c = a + b
- saver = tf.train.Saver([a,b])
- sess = tf.Session()
- saver.restore(sess,"model.ckpt")
- print(sess.run(a))
- #[ 1.]
- print(sess.run(b))
- #[ 2.]</span>
[[Node: _retval_Variable_1_0_0 = _Retval[T=DT_FLOAT, index=0, _device="/job:localhost/replica:0/task:0/cpu:0"](Variable_1)]],使用一个没有初始化的变量。
4、加载变量名的重命名
tensorfow提供了一种方法可以修改加载模型中的变量名,通过tf.train.Saver(),带参的形式来修改变量名称。
- <span style="font-size:14px;"> #重新定义两个变量v1和v2
- v1 = tf.Variable(tf.constant(3.,shape=[1]),name="v1")
- v2 = tf.Variable(tf.constant(4.,shape=[1]),name="v2")
- #将模型中的变量名a重命名为v1,将模型中的变量名b重命名为v2
- save = tf.train.Saver({"a":v1,"b":v2})
- sess = tf.Session()
- save.restore(sess,"model.ckpt")
- print(sess.run(v1))
- print(sess.run(v2))</span>