转 ,原文详见: http://blog.csdn.net/sinat_29957455/article/details/78511119
前面介绍了通过使用tf.train.Saver函数来保存TensorFlow程序的参数,但是,在使用tf.train.Saver函数保存模型文件的时候,是保存所有的参数信息,而有些时候我们并不需要所有的参数信息。我们只需要知道神经网络的输入层经过前向传播计算得到输出层即可,所以在保存的时候,我们也不需要保存所有的参数,以及变量的初始化、模型保存等辅助节点信息与迁移学习类似。之前使用tf.train.Saver函数保存模型文件的时候会产生多个文件,它将变量的取值和计算图结构分成了不同的文件存储。TensorFlow提供了另一种保存模型文件的方法,将计算图保存在一个文件中。
1、模型文件的保存
- import tensorflow as tf
- from tensorflow.python.framework import graph_util
- from tensorflow.python.platform import gfile
- if __name__ == "__main__":
- a = tf.Variable(tf.constant(5.,shape=[1]),name="a")
- b = tf.Variable(tf.constant(6.,shape=[1]),name="b")
- c = a + b
- init = tf.initialize_all_variables()
- sess = tf.Session()
- sess.run(init)
- #导出当前计算图的GraphDef部分
- graph_def = tf.get_default_graph().as_graph_def()
- #保存指定的节点,并将节点值保存为常数
- output_graph_def = graph_util.convert_variables_to_constants(sess,graph_def,['add'])
- #将计算图写入到模型文件中
- model_f = tf.gfile.GFile("model.pb","wb")
- model_f.write(output_graph_def.SerializeToString())
2、模型文件的读取
- sess = tf.Session()
- #将保存的模型文件解析为GraphDef
- model_f = gfile.FastGFile("model.pb",'rb')
- graph_def = tf.GraphDef()
- graph_def.ParseFromString(model_f.read())
- c = tf.import_graph_def(graph_def,return_elements=["add:0"])
- print(sess.run(c))
- #[array([ 11.], dtype=float32)]
版权声明:本文为博主原创文章,未经博主允许不得转载。 http://blog.csdn.net/sinat_29957455/article/details/78511119