TensorFlow Serving模型转换

Tensorflow训练的模型,如果想使用TensorFlow Serving进行部署,需要将ckpt模型转换为pb模型。

一、模型格式转变

1、原文件格式:
在这里插入图片描述
2、新文件格式:
在这里插入图片描述

二、模型转化代码:

做3个地方修改即可。
1、需要结合自己网络结构的输入输出参数进行修改
2、定义模型的输入输出,建立调用接口与tensor签名之间的映射
3、设置原模型目录、新模型目录、版本号

#coding:utf-8
import sys, os, io
import tensorflow as tf

def restore_and_save(input_checkpoint, export_path_base):
    checkpoint_file = tf.train.latest_checkpoint(input_checkpoint)
    graph = tf.Graph()

    with graph.as_default():
        session_conf = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)
        sess = tf.Session(config=session_conf)

        with sess.as_default():
            # 载入保存好的meta graph,恢复图中变量,通过SavedModelBuilder保存可部署的模型
            saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file))
            saver.restore(sess, checkpoint_file)
            print(graph.get_name_scope())
            for node in graph.as_graph_def().node:
                print('node.name:', node.name)
            export_path_base = export_path_base
            export_path = os.path.join(
                tf.compat.as_bytes(export_path_base),
                tf.compat.as_bytes(str(count)))
            print('Exporting trained model to', export_path)
            builder = tf.saved_model.builder.SavedModelBuilder(export_path)

            # 建立签名映射,需要包括计算图中的placeholder(ChatInputs, SegInputs, Dropout)和我们需要的结果(project/logits,crf_loss/transitions)
            """
            build_tensor_info:建立一个基于提供的参数构造的TensorInfo protocol buffer,
            输入:tensorflow graph中的tensor;
            输出:基于提供的参数(tensor)构建的包含TensorInfo的protocol buffer
                        get_operation_by_name:通过name获取checkpoint中保存的变量,能够进行这一步的前提是在模型保存的时候给对应的变量赋予name
            """
            
			#1.需要结合自己网络结构的输入输出参数进行修改
            left_inputs =tf.saved_model.utils.build_tensor_info(graph.get_operation_by_name("left").outputs[0])
            right_inputs =tf.saved_model.utils.build_tensor_info(graph.get_operation_by_name("right").outputs[0])
			#grcp调用方式会使用到,http方式不会使用的
            output_prob =tf.saved_model.utils.build_tensor_info(graph.get_operation_by_name("output_prob").outputs[0])

            """
            signature_constants:SavedModel保存和恢复操作的签名常量。
            在序列标注的任务中,这里的method_name是"tensorflow/serving/predict"
            """
            
            # 2.定义模型的输入输出,建立调用接口与tensor签名之间的映射
            labeling_signature = (
                tf.saved_model.signature_def_utils.build_signature_def(
                    inputs={
                        "leftinputs":
                            left_inputs,
                        "rightinputs":
                            right_inputs,
                    },
                    outputs={
                        "output":
                            output_prob
                    },
                    method_name="tensorflow/serving/predict"))

            """
            tf.group : 创建一个将多个操作分组的操作,返回一个可以执行所有输入的操作
            """
            legacy_init_op = tf.group(tf.tables_initializer(), name='legacy_init_op')

            """
            add_meta_graph_and_variables:建立一个Saver来保存session中的变量,
                                          输出对应的原图的定义,这个函数假设保存的变量已经被初始化;
                                          对于一个SavedModelBuilder,这个API必须被调用一次来保存meta graph;
                                          对于后面添加的图结构,可以使用函数 add_meta_graph()来进行添加
            """
            # 建立模型名称与模型签名之间的映射
            builder.add_meta_graph_and_variables(
                sess, [tf.saved_model.tag_constants.SERVING],
                # 保存模型的方法名,与客户端的request.model_spec.signature_name对应
                signature_def_map={
                    tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
                       labeling_signature},
                legacy_init_op=legacy_init_op)

            builder.save()
            print("Build Done")
            
#3.设置原模型目录、新模型目录、版本号
### 测试模型转换
tf.flags.DEFINE_string("ckpt_path",     "model_cos_leaky_relu_0220",             "path of source checkpoints")
tf.flags.DEFINE_string("pb_path",       "servable_model",             "path of servable models")
tf.flags.DEFINE_string("version",      '00000123',              "the number of model version")
# tf.flags.DEFINE_string("classes",       'LOC',          "multi-models to be converted")
FLAGS = tf.flags.FLAGS

# classes = FLAGS.classes
input_checkpoint = FLAGS.ckpt_path
model_path = FLAGS.pb_path

# 版本号控制
count = FLAGS.version
modify = False
if not os.path.exists(model_path):
    os.mkdir(model_path)

# 模型格式转换
restore_and_save(input_checkpoint, model_path)

参考链接:TensorFlow Serving使用指南 https://www.jianshu.com/p/d11a5c3dc757

  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值