显示pb模型节点信息:
import tensorflow as tf
import os
# Location of the serialized GraphDef (.pb) file that is loaded further down.
model_dir = './'
model_name = 'ocr.pb'

# Start from an empty default graph so nodes created earlier in the same
# process are not mixed with the imported model. The `tf.compat.v1` spelling
# is used because the bare `tf.reset_default_graph` alias only exists in
# TensorFlow 1.x.
tf.compat.v1.reset_default_graph()
def create_graph():
    """Import the GraphDef stored at ``model_dir/model_name`` into the default graph.

    The file is read in binary mode, parsed into a ``GraphDef`` message and
    imported with ``name=''`` (an empty name prefix).

    Returns:
        None. The imported nodes are added to the current default graph.
    """
    model_path = os.path.join(model_dir, model_name)
    # `tf.gfile.FastGFile` is deprecated in favour of `GFile`; the
    # `tf.io.gfile` / `tf.compat.v1` spellings used here match the other
    # snippets in this file and exist in TensorFlow 1.x (>= 1.13) and 2.x.
    with tf.io.gfile.GFile(model_path, 'rb') as f:
        graph_def = tf.compat.v1.GraphDef()
        graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def, name='')
create_graph()

# Show the fields of the first node in the imported graph; the loop exits
# after a single iteration.
for node in tf.compat.v1.get_default_graph().as_graph_def().node:
    # print(node) would dump the whole NodeDef instead of selected fields.
    print(node.op)  # op type, e.g. "Placeholder"
    print(node.name)
    print(node.attr)
    # Names of the nodes feeding this node, shown with the other fields.
    inputs = [str(in_name) for in_name in node.input]
    print(inputs)
    break
部分转载自:TensorFlow查看输入节点和输出节点名称_miao0967020148的专栏-CSDN博客
保存简单网络为pb文件
import tensorflow as tf
import numpy as np
# Conv2D test configuration. With data_format "NCHW" the data shape reads as
# batch=1, channels=3, height=32, width=32.
data_format = "NCHW"
dtype = "float32"
padding = "SAME"
data_shape = [1, 3, 32, 32]
filter_shape = [3, 3, 3, 64]
strides = [1, 1, 1, 1]
dilations = [1, 1, 1, 1]

# Graph inputs. NOTE: `input` and `filter` shadow Python builtins; the names
# are kept because the session run later in this file uses them as feed keys.
input = tf.placeholder(shape=data_shape, dtype=dtype)
filter = tf.placeholder(shape=filter_shape, dtype=dtype)

# Fixed seed so the random feed values are reproducible. The draw order
# (data first, then filter) determines the values each array receives.
np.random.seed(0)
input_np = np.random.rand(*data_shape).astype(dtype)
filter_np = np.random.rand(*filter_shape).astype(dtype)

# The op is named "conv2d_1" so it can be referenced by name when freezing.
conv1_out = tf.nn.conv2d(
    input,
    filter,
    strides=strides,
    padding=padding,
    data_format=data_format,
    dilations=dilations,
    name="conv2d_1",
)
def freeze_graph(output_node_names, output_graph_name):
    """Write the default graph to disk as a text dump and as a frozen ``.pb``.

    Args:
        output_node_names: Comma-separated names of the output nodes to keep,
            e.g. ``"conv2d_1"`` or ``"out_a, out_b"``. Whitespace around each
            name is ignored.
        output_graph_name: Path of the binary frozen graph to write. A text
            dump of the unfrozen graph is written under ``./`` using the same
            name with its extension replaced by ``.pbtxt``.

    Returns:
        None. Two files are written as a side effect.
    """
    import os  # local import keeps the function self-contained

    graph = tf.compat.v1.get_default_graph()
    input_graph_def = graph.as_graph_def()

    # Replace the real extension instead of blindly dropping three characters,
    # so names that do not end in ".pb" still produce a sensible file name.
    pbtxt_model_name = os.path.splitext(output_graph_name)[0] + ".pbtxt"
    tf.io.write_graph(input_graph_def, './', pbtxt_model_name)

    # Strip blanks so "a, b" resolves to nodes "a" and "b" rather than " b".
    node_names = [name.strip() for name in output_node_names.split(",")]

    # Use the tf.compat.v1 / tf.io spellings throughout; the function already
    # relied on tf.compat.v1.graph_util, so this keeps it consistent.
    with tf.compat.v1.Session() as sess:
        output_graph_def = tf.compat.v1.graph_util.convert_variables_to_constants(
            sess=sess,
            input_graph_def=input_graph_def,
            output_node_names=node_names,
        )
    with tf.io.gfile.GFile(output_graph_name, "wb") as f:
        f.write(output_graph_def.SerializeToString())
# Evaluate the convolution once with the random feed values and print it.
feed_values = {input: input_np, filter: filter_np}
with tf.Session() as session:
    result1 = session.run(conv1_out, feed_dict=feed_values)
print(result1)

# Export the graph, keeping the node named "conv2d_1" as the output.
freeze_graph("conv2d_1", "conv2d_test.pb")
# NOTE(review): `sess`, `output_names` and `dump_graph_path` are not assigned
# in this snippet; they are expected to come from the surrounding code.
#
# Alternative way to obtain the GraphDef, from the default graph instead of
# the session:
# graph = tf.get_default_graph()
# graph_def = graph.as_graph_def()
graph_def = sess.graph_def
# Freeze the GraphDef obtained above. It is passed as `graph_def`, the name
# assigned on the previous statement, not the unrelated `input_graph_def`.
output_graph_def = tf.compat.v1.graph_util.convert_variables_to_constants(
    sess=sess,
    input_graph_def=graph_def,
    output_node_names=output_names)
with tf.io.gfile.GFile(dump_graph_path, "wb") as f:
    f.write(output_graph_def.SerializeToString())
pb / pbtxt 模型文件读取
# Load a GraphDef from `model_path`, choosing the parser by file extension.
# Running both parsers on the same path cannot work: the file is either
# binary or text, and parsing twice into one message would merge the results.
# NOTE(review): `model_path` is not assigned in this snippet; it is expected
# to come from the surrounding code.
graph_def = tf.compat.v1.GraphDef()
if str(model_path).lower().endswith(".pbtxt"):
    # pbtxt: text-format protobuf.
    # NOTE(review): `text_format` is not imported in the visible source;
    # presumably `from google.protobuf import text_format` -- confirm.
    with open(model_path, "r") as pf:
        text_format.Parse(pf.read(), graph_def)
else:
    # pb: binary serialized protobuf.
    with tf.io.gfile.GFile(model_path, "rb") as f:
        graph_def.ParseFromString(f.read())
tf2 frozen graph推理
参考:https://github.com/tensorflow/ngraph-bridge/blob/master/examples/infer_image.py
import time
import tensorflow as tf
import numpy as np
import os
import pdb
import sys
# Model file to load and the random batch that is fed to it.
model_file = "./model.pb"
dtype = np.float32
# NOTE(review): presumably (batch, height, width, channels) -- confirm against
# the model's input signature.
data_shape = (1, 512, 512, 3)
# Uniform samples in [0, 1), cast to the configured dtype.
test_data = np.random.random(size=data_shape).astype(dtype)
def load_graph(model_file):
    """Read a binary GraphDef from ``model_file`` and import it into a new graph.

    Args:
        model_file: Path of the serialized GraphDef file, opened in binary mode.

    Returns:
        A newly created ``tf.Graph`` containing the imported nodes. No ``name``
        argument is passed to ``tf.import_graph_def``, so its default name
        prefix applies to the imported nodes.
    """
    with open(model_file, "rb") as model_fh:
        serialized = model_fh.read()

    parsed_def = tf.compat.v1.GraphDef()
    parsed_def.ParseFromString(serialized)

    loaded = tf.Graph()
    with loaded.as_default():
        tf.import_graph_def(parsed_def)
    return loaded
# Load the model and look up its input / output operations by name. The
# names carry the "import/" prefix in front of "input_1" / "output_1".
graph = load_graph(model_file)
input_name = "import/input_1"
output_name = "import/output_1"
input_operation = graph.get_operation_by_name(input_name)
output_operation = graph.get_operation_by_name(output_name)

print('start infer ............................')
with tf.compat.v1.Session(graph=graph) as sess:
    # One forward pass: fetch the first output tensor of the output operation
    # while feeding the test data to the first output tensor of the input op.
    fetch = output_operation.outputs[0]
    feed = {input_operation.outputs[0]: test_data}
    results = np.squeeze(sess.run(fetch, feed))
print(results)