import glob
import os.path
import numpy as np
import tensorflow as tf
import random
from tensorflow.python.platform import gfile
from tensorflow.python.framework import graph_util
# Save the model: build a small graph, freeze its variables into constants,
# and write the resulting GraphDef to ./model.pb.
# The graph computes add = (v1 * v2) + v3 on single-element float tensors.
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name='v1')
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name='v2')
v3 = tf.Variable(tf.constant(4.0, shape=[1]), name='v3')
result = tf.multiply(v1, v2, name='mul')
result1 = tf.add(result, v3, name='add')
init_op = tf.global_variables_initializer()
with tf.Session() as sess:
    # The variables need values in this session before they can be frozen.
    sess.run(init_op)
    graph_def = tf.get_default_graph().as_graph_def()
    # Every node that 'add' depends on is kept, so the 'mul' op is included
    # in the frozen graph as well.
    output_graph_def = graph_util.convert_variables_to_constants(
        sess, graph_def, ['add'])
# Serializing the frozen GraphDef does not need the session, so the file is
# written after the session has been closed.
with tf.gfile.GFile('./model.pb', 'wb') as f:
    f.write(output_graph_def.SerializeToString())
# Load the model: read the serialized GraphDef from ./model.pb, import it
# into the current default graph, and evaluate the 'add' and 'mul' outputs.
with tf.Session() as sess:
    model_filename = './model.pb'
    # tf.gfile.GFile is used instead of the deprecated gfile.FastGFile.
    with tf.gfile.GFile(model_filename, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    # return_elements hands back the imported tensors directly, in the order
    # requested: first the output of 'add', then the output of 'mul'.
    result, vv = tf.import_graph_def(
        graph_def, return_elements=['add:0', 'mul:0'])
    print(sess.run(result))
    print(sess.run(vv))