最近在做深度学习的工程应用部署,需要使用C++调用tensorflow的API使用训练好的模型进行预测,所以需要把checkpoint的文件转成pb文件做离线调用。
1. ckpt文件
2. 获得输入输出节点
转pb文件时需要知道确切的输出节点,调用时需要知道输入节点。由ckpt文件获取输入输出节点
2.1 使用python获得所有节点
import tensorflow as tf
input_checkpoint = './ckpt_weight_c1/iter_2500_dic_coeff_0.99'
saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
graph = tf.get_default_graph() # 获得默认的图
for op in graph.get_operations(): # 打印出所有Graph中节点的名称。
print(op.name)
第一个是输入节点,最后一个是输出节点。
输入节点就是模型的输入,也就是定义的作为输入的那个placeholder的名称,有模型代码的话可以在代码中看。
2.2 由ckpt生成events log文件使用tensorboard查看
import tensorflow as tf
g = tf.Graph()
with g.as_default() as g:
tf.train.import_meta_graph('./ckpt_weight/iter_2500_dic_coeff_0.99.meta')
with tf.Session(graph=g) as sess:
file_writer = tf.summary.FileWriter(logdir='./graph', graph=g)
如何在win7下使用tensorboard可以看我的其他博客,有写的。
tensorboard显示的GRAPHS的模型的第一层是输入,最后一层是输出。
3. ckpt转成pb文件
import tensorflow as tf
from tensorflow.python.framework import graph_util
def freeze_graph(input_checkpoint, output_graph):
'''
:param input_checkpoint:ckpt模型的路径
:param output_graph: PB模型保存路径
:return:
'''
output_node_names = "g_/Sigmoid"
saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
graph = tf.get_default_graph() # 获得默认的图
with tf.Session() as sess:
saver.restore(sess, input_checkpoint)
output_graph_def = graph_util.convert_variables_to_constants(
sess=sess, input_graph_def=sess.graph_def, output_node_names=output_node_names.split(","))
with tf.gfile.GFile(output_graph, "wb") as f:
f.write(output_graph_def.SerializeToString())
使用pb进行预测在上一篇keras的h5转pb博客里有写。
以上就是在python中实现tensorflow的ckpt文件转pb文件的全过程。