程序一:ckpt转pb
import tensorflow as tf
from tensorflow.python.framework import graph_util
from tensorflow.python.platform import gfile
# 模型参数固化ckpt转pb
def freeze_graph(input_meta,input_checkpoint, output_graph):
'''
:param input_checkpoint:
:param output_graph: PB模型保存路径
:return:
'''
# 指定输出的节点名称,该节点名称必须是原模型中存在的节点
output_node_names = "XXXXX"
saver = tf.train.import_meta_graph(input_meta, clear_devices=True) # + '.meta'
graph = tf.get_default_graph() # 获得默认的图
input_graph_def = graph.as_graph_def() # 返回一个序列化的图代表当前的图
with tf.Session() as sess:
saver.restore(sess, input_checkpoint) # 恢复图并得到数据
output_graph_def = graph_util.convert_variables_to_constants( # 模型持久化,将变量值固定
sess=sess,
input_graph_def=input_graph_def, # 等于:sess.graph_def
output_node_names=output_node_names.split(",")) # 如果有多个输出节点,以逗号隔开
with gfile.GFile(output_graph, "wb") as f: # 保存模型
f.write(output_graph_def.SerializeToString()) # 序列化输出
print("%d ops in the final graph." % len(output_graph_def.node)) # 得到当前图有几个操作节点
程序二:测试是否转对了
# 测试
def testPb():
'''
:param pb_path:pb文件的路径
:param image_path:测试图片的路径
:return:
'''
pb_path = "XXXXX.pb"
with tf.Graph().as_default():
output_graph_def = tf.GraphDef()
if (os.path.isfile(pb_path)):
with open(pb_path, "rb") as f:
output_graph_def.ParseFromString(f.read())
tf.import_graph_def(output_graph_def, name = "")
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
# 定义输入的张量名称,对应网络结构的输入张量
input= tf.get_default_graph().get_tensor_by_name("input:0")
is_train = tf.get_default_graph().get_tensor_by_name("is_train:0")
# 定义输出的张量名称
output_tensor_name = sess.graph.get_tensor_by_name("XXXXXXX:0")
out = sess.run(output_tensor_name, feed_dict={input: XXX,
is_train : False})
print("output:{}".format(out))
其他:
可能会出现错误:
ValueError: Input 0 of node XXXXXXXXXXX/Switch was passed float from XXXXXXXXXXXXXxBathNormalXXXXXXX:0 incompatible with expected float_ref.
原因,转pb的时候BN层是float_ref,而转pb后为float
程序上可以做如下修改
程序二
# 测试
def testPb():
'''
:param pb_path:pb文件的路径
:param image_path:测试图片的路径
:return:
'''
pb_path = "XXXXX.pb"
with tf.Graph().as_default():
output_graph_def = tf.GraphDef()
if (os.path.isfile(pb_path)):
with open(pb_path, "rb") as f:
output_graph_def.ParseFromString(f.read())
for node in output_graph_def.node:
if node.op == 'RefSwitch':
node.op = 'Switch'
for index in range(len(node.input)):
if 'moving_' in node.input[index]:
node.input[index] = node.input[index] + '/read'
elif node.op == 'AssignSub':
node.op = 'Sub'
if 'use_locking' in node.attr:
del node.attr['use_locking']
tf.import_graph_def(output_graph_def, name = "")
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
# 定义输入的张量名称,对应网络结构的输入张量
input= tf.get_default_graph().get_tensor_by_name("input:0")
is_train = tf.get_default_graph().get_tensor_by_name("is_train:0")
# 定义输出的张量名称
output_tensor_name = sess.graph.get_tensor_by_name("XXXXXXX:0")
out = sess.run(output_tensor_name, feed_dict={input: XXX,
is_train : False})
print("output:{}".format(out))
测试输出与未转化前完全一致,end!