A. 保存为pb文件的第一种方法（将变量冻结为常量后写入单个 .pb 文件）
import tensorflow as tf
from tensorflow.python.framework import graph_util
# Build a small graph computing op = x * y + b, freeze it, and write it to
# ./model.pb as a single serialized GraphDef.
with tf.Session(graph=tf.Graph()) as sess:
    x = tf.placeholder(tf.int32, name='x')
    y = tf.placeholder(tf.int32, name='y')
    b = tf.Variable(1, name='b')
    op = tf.add(tf.multiply(x, y), b, name='op')

    # The variable b must be initialized before it can be read or frozen.
    sess.run(tf.global_variables_initializer())

    feed_dict = {x: 10, y: 4}
    print(sess.run(op, feed_dict))  # show the result of running the graph

    # graph_util converts the variables into constants, keeping the nodes
    # needed to compute the output node 'op'.
    con_graph = graph_util.convert_variables_to_constants(
        sess, sess.graph_def, ['op'])

    # tf.gfile.GFile replaces the deprecated tf.gfile.FastGFile alias.
    with tf.gfile.GFile('./model.pb', mode='wb') as f:  # output path
        f.write(con_graph.SerializeToString())  # serialized frozen GraphDef
B. 保存为pb文件的第二种方法（使用 SavedModelBuilder 保存为 SavedModel 目录）
import tensorflow as tf
# Build the same graph (op = x * y + b) and export it as a SavedModel
# directory at ./modeldir, tagged 'cpu_server_1'.
with tf.Session(graph=tf.Graph()) as sess:
    x = tf.placeholder(tf.int32, name='x')
    y = tf.placeholder(tf.int32, name='y')
    b = tf.Variable(1, name='b')
    op = tf.add(tf.multiply(x, y), b, name='op')

    # The variable b must be initialized before its value can be saved.
    sess.run(tf.global_variables_initializer())

    # Optional sanity check: sess.run(op, {x: 10, y: 4}) evaluates 10 * 4 + 1.

    builder = tf.saved_model.builder.SavedModelBuilder('./modeldir')
    # The same tag list must be supplied again when the model is loaded.
    builder.add_meta_graph_and_variables(sess, ['cpu_server_1'])
    builder.save()
A. 读取并调用pb文件的第一种方法（对应保存方法A生成的 ./model.pb）
import tensorflow as tf
# Read the frozen GraphDef from ./model.pb, import it into a fresh graph,
# and evaluate the 'op' tensor with new inputs.
# tf.gfile.GFile replaces the deprecated tf.gfile.FastGFile alias.
with tf.gfile.GFile('./model.pb', mode='rb') as f:
    graph_def = tf.GraphDef()  # empty GraphDef message to fill in
    graph_def.ParseFromString(f.read())  # deserialize bytes into the message

# A dedicated tf.Graph() (same pattern as the save snippets) keeps the
# imported nodes out of the process-wide default graph.
with tf.Session(graph=tf.Graph()) as sess:
    # name='' imports the nodes without a prefix, so tensors keep their
    # original names ('x:0', 'y:0', 'op:0').
    tf.import_graph_def(graph_def, name='')

    # No variable initializer is run: the graph was frozen, so its
    # variables were already converted to constants when it was saved.
    inputx = sess.graph.get_tensor_by_name('x:0')
    inputy = sess.graph.get_tensor_by_name('y:0')
    op = sess.graph.get_tensor_by_name('op:0')
    print(sess.run(op, {inputx: 5, inputy: 5}))
B. 读取并调用pb文件的第二种方法（对应保存方法B生成的 SavedModel 目录 ./modeldir）
import tensorflow as tf
# Load the SavedModel stored in ./modeldir and print the operations of the
# restored graph.
# A dedicated tf.Graph() (same pattern as the save snippets) keeps the
# loaded nodes out of the process-wide default graph.
with tf.Session(graph=tf.Graph()) as sess:
    # No variable initializer is run before loading: the graph is still
    # empty at this point, and loader.load() restores the saved variables.
    # The tags must match the ones used when saving ('cpu_server_1').
    tf.saved_model.loader.load(
        sess, tags=['cpu_server_1'], export_dir='./modeldir')

    operations = sess.graph.get_operations()  # every operation in the graph
    print(operations)
列出未知pb文件中的全部操作（节点）
import tensorflow as tf
# Read a frozen GraphDef from disk, import it into a fresh graph, and print
# every operation it contains, to discover node names in an unfamiliar model.
# tf.gfile.GFile replaces the deprecated tf.gfile.FastGFile alias.
with tf.gfile.GFile(r'E:\faster_rcnn_resnet101_coco_11_06_2017\frozen_inference_graph.pb', mode='rb') as f:
    graph_def = tf.GraphDef()  # empty GraphDef message to fill in
    graph_def.ParseFromString(f.read())  # deserialize bytes into the message

# A dedicated tf.Graph() (same pattern as the save snippets) keeps the
# imported nodes out of the process-wide default graph.
with tf.Session(graph=tf.Graph()) as sess:
    # name='' imports the nodes without a prefix, so names are unchanged.
    tf.import_graph_def(graph_def, name='')

    # No variable initializer is run: nothing in this snippet creates
    # variables, and import_graph_def only adds the nodes from the file.
    operations = sess.graph.get_operations()  # every operation in the graph
    print(operations)
部分结果