概述
需要从ckpt文件生成固化的PB文件,给生成现场用。现在我探索下不依赖代码里的网络结构、仅仅用CKPT文件来生成PB文件。
转化
训练生成的ckpt文件目录如下:
$ ls ./ckpt0507/
checkpoint graph.pbtxt model.ckpt-3251.data-00000-of-00001 model.ckpt-6500.index model.ckpt-6501.meta
events.out.tfevents.1557215392.amax model.ckpt-3250.data-00000-of-00001 model.ckpt-3251.index model.ckpt-6500.meta model.ckpt-9750.data-00000-of-00001
events.out.tfevents.1557217802.amax model.ckpt-3250.index model.ckpt-3251.meta model.ckpt-6501.data-00000-of-00001 model.ckpt-9750.index
events.out.tfevents.1557219123.amax model.ckpt-3250.meta model.ckpt-6500.data-00000-of-00001 model.ckpt-6501.index model.ckpt-9750.meta
转化PB文件的代码如下:
import tensorflow as tf
import os
def read_graph_from_ckpt(ckpt_path, out_pb_path, output_name ):
# 从meta文件加载网络结构
saver = tf.train.import_meta_graph(ckpt_path+'.meta',clear_devices=True)
graph = tf.get_default_graph()
with tf.Session( graph=graph) as sess:
sess.run(tf.global_variables_initializer())
# 从ckpt加载参数
saver.restore(sess, ckpt_path)
output_tf =graph.get_tensor_by_name(output_name)
# 固化
pb_graph = tf.graph_util.convert_variables_to_constants( sess, graph.as_graph_def(), [output_tf.op.name])
# 保存
with tf.gfile.FastGFile(out_pb_path, mode='wb') as f:
f.write(pb_graph.SerializeToString())
read_graph_from_ckpt('./ckpt0507/model.ckpt-9750', './idcard_seg.pb', 'decoder/upsampling_2_logits/conv_1x1/BiasAdd:0')
调用read_graph_from_ckpt()主要需要知道输出tensor节点的名称。我是通过Netron这个神经网络可视化工具查看的节点名称。
总结
关于如何导出训练模型,我之前的一篇文章《tensorflow 20:搭网络、导出模型、运行模型》也有涉及。
无论是那篇文章还是本文,都是调用convert_variables_to_constants()完成的模型固化。调用convert_variables_to_constants的前提就是session已经绑定了一个计算图,这个计算图可以是刚刚训练的,也可以是从磁盘文件读取的,或者其它格式文件加载到内存中的,然后这个计算图被固化保存到磁盘上的PB文件。