Tensorflow训练的模型,如果想使用TensorFlow Serving进行部署,需要将ckpt模型转换为pb模型。
一、模型格式转变
1、原文件格式:
2、新文件格式:
二、模型转化代码:
做3个地方修改即可。
1、需要结合自己网络结构的输入输出参数进行修改
2、定义模型的输入输出,建立调用接口与tensor签名之间的映射
3、设置原模型目录、新模型目录、版本号
#coding:utf-8
import sys, os, io
import tensorflow as tf
def restore_and_save(input_checkpoint, export_path_base):
checkpoint_file = tf.train.latest_checkpoint(input_checkpoint)
graph = tf.Graph()
with graph.as_default():
session_conf = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)
sess = tf.Session(config=session_conf)
with sess.as_default():
# 载入保存好的meta graph,恢复图中变量,通过SavedModelBuilder保存可部署的模型
saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file))
saver.restore(sess, checkpoint_file)
print(graph.get_name_scope())
for node in graph.as_graph_def().node:
print('node.name:', node.name)
export_path_base = export_path_base
export_path = os.path.join(
tf.compat.as_bytes(export_path_base),
tf.compat.as_bytes(str(count)))
print('Exporting trained model to', export_path)
builder = tf.saved_model.builder.SavedModelBuilder(export_path)
# 建立签名映射,需要包括计算图中的placeholder(ChatInputs, SegInputs, Dropout)和我们需要的结果(project/logits,crf_loss/transitions)
"""
build_tensor_info:建立一个基于提供的参数构造的TensorInfo protocol buffer,
输入:tensorflow graph中的tensor;
输出:基于提供的参数(tensor)构建的包含TensorInfo的protocol buffer
get_operation_by_name:通过name获取checkpoint中保存的变量,能够进行这一步的前提是在模型保存的时候给对应的变量赋予name
"""
#1.需要结合自己网络结构的输入输出参数进行修改
left_inputs =tf.saved_model.utils.build_tensor_info(graph.get_operation_by_name("left").outputs[0])
right_inputs =tf.saved_model.utils.build_tensor_info(graph.get_operation_by_name("right").outputs[0])
#grcp调用方式会使用到,http方式不会使用的
output_prob =tf.saved_model.utils.build_tensor_info(graph.get_operation_by_name("output_prob").outputs[0])
"""
signature_constants:SavedModel保存和恢复操作的签名常量。
在序列标注的任务中,这里的method_name是"tensorflow/serving/predict"
"""
# 2.定义模型的输入输出,建立调用接口与tensor签名之间的映射
labeling_signature = (
tf.saved_model.signature_def_utils.build_signature_def(
inputs={
"leftinputs":
left_inputs,
"rightinputs":
right_inputs,
},
outputs={
"output":
output_prob
},
method_name="tensorflow/serving/predict"))
"""
tf.group : 创建一个将多个操作分组的操作,返回一个可以执行所有输入的操作
"""
legacy_init_op = tf.group(tf.tables_initializer(), name='legacy_init_op')
"""
add_meta_graph_and_variables:建立一个Saver来保存session中的变量,
输出对应的原图的定义,这个函数假设保存的变量已经被初始化;
对于一个SavedModelBuilder,这个API必须被调用一次来保存meta graph;
对于后面添加的图结构,可以使用函数 add_meta_graph()来进行添加
"""
# 建立模型名称与模型签名之间的映射
builder.add_meta_graph_and_variables(
sess, [tf.saved_model.tag_constants.SERVING],
# 保存模型的方法名,与客户端的request.model_spec.signature_name对应
signature_def_map={
tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
labeling_signature},
legacy_init_op=legacy_init_op)
builder.save()
print("Build Done")
#3.设置原模型目录、新模型目录、版本号
### 测试模型转换
tf.flags.DEFINE_string("ckpt_path", "model_cos_leaky_relu_0220", "path of source checkpoints")
tf.flags.DEFINE_string("pb_path", "servable_model", "path of servable models")
tf.flags.DEFINE_string("version", '00000123', "the number of model version")
# tf.flags.DEFINE_string("classes", 'LOC', "multi-models to be converted")
FLAGS = tf.flags.FLAGS
# classes = FLAGS.classes
input_checkpoint = FLAGS.ckpt_path
model_path = FLAGS.pb_path
# 版本号控制
count = FLAGS.version
modify = False
if not os.path.exists(model_path):
os.mkdir(model_path)
# 模型格式转换
restore_and_save(input_checkpoint, model_path)
参考链接:TensorFlow Serving使用指南 https://www.jianshu.com/p/d11a5c3dc757