一、保存:
graph_util.convert_variables_to_constants 可以把当前session的计算图串行化成一个字节流(二进制),这个函数包含三个参数:参数1:当前活动的session,它含有各变量
参数2:GraphDef 对象,它描述了计算网络
参数3:Graph图中需要输出的节点的名称的列表
返回值:精简版的GraphDef 对象,包含了原始输入GraphDef和session的网络和变量信息,它的成员函数SerializeToString()可以把这些信息串行化为字节流,然后写入文件里:
constant_graph = graph_util.convert_variables_to_constants( sess, sess.graph_def , ['sum_operation'] )
with open( pbName, mode='wb') as f:
f.write(constant_graph.SerializeToString())
需要指出的是,如果原始张量(包含在参数1和参数2中的组成部分)不参与参数3指定的输出节点列表所指定的张量计算的话,这些张量将不会存在返回的GraphDef对象里,也不会被串行化写入pb文件。
二、恢复:
恢复时,创建一个GraphDef,然后从上述的文件里加载进来,接着输入到当前的session:
graph0 = tf.GraphDef()
with open( pbName, mode='rb') as f:
graph0.ParseFromString( f.read() )
tf.import_graph_def( graph0 , name = '' )
三、代码:
import tensorflow as tf
from tensorflow.python.framework import graph_util
pbName = 'graphA.pb'
def graphCreate() :
with tf.Session() as sess :
var1 = tf.placeholder ( tf.int32 , name='var1' )
var2 = tf.