将计算图中 的变量及取值通过变量的方式保存
graph_util 模块中的函数
import tensorflow as tf
from tensorflow.python.framwork import graph_util
v1 = tf.variable(tf.constant(1.0, shape = [1]), name = 'v1')
v2 = tf.variable(tf.constant(2.0, shape = [2]), name = 'v2')
result = v1 + v2
init_op = tf.initialize_all_variables()
with tf.Session() as sess:
sess.run(init_op)
graph_def = tf.get_default_graph().as_graph_def()
#导出当期计算图的GraphDef 部分, 这部分可以完成 从 输入层到输出层的计算过程
#将图中的变量及其取值转换成 常量, 同时将图中不必要的节点去掉。
#最后一个参数"add",给出了需要保存的节点名称。add是上面定义的两个变量相加的操作。注意这里给出的是计算节点名称,所以后面没有:0
output_graph_def = graph_util.convert_variables_to_constants(sess, graph_def, ['add'])
#将导出的模型存入文件
with tf.gfile.GFile("/path/to/model/combined_model.pb", "wb") as f:
f.write(output_graph_def.SerializeToString())