1、保存pb,一定要在模型训练完成之后再保存,否则保存的不是最终数据,具体的输出节点名字,要看自己的网络结构,或者通过打印模型节点输出的形式查看
constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def,
['proj/logits', 'transition_params'])
with tf.gfile.FastGFile(self.model_path + 'model.pb', mode='wb') as f:
f.write(constant_graph.SerializeToString())
2、加载pb
with tf.Session() as sess:
'''Load model from *.pb'''
with gfile.FastGFile(intent_model_path, "rb") as f:
new_graph = tf.GraphDef()
new_graph.ParseFromString(f.read())
tf.import_graph_def(new_graph, name='')