ckpt模型转换为tf serving的saved model格式

最近这段时间又开始在弄部署问题,使用的是Google的Tensorflow serving框架,使用的环境是Ubuntu16.0.4+docker+tensorflow serving。如果需要知道这个框架搭建及使用,可以看我之前的博客,环境搭建模型部署测试模型部署时的GPU设置多模型在线部署。很久之后再弄,TF Serving的简单使用还是没有问题的,根据我之前的博客就可以做到模型的部署和服务申请以及多模型部署和测试。之前都是学习,所以都是简单的模型进行部署和简单的测试,在实际部署时会有很多问题。首先,在将训练好的模型转换为saved model的时候就会出现很多问题。这里主要讨论将TensorFlow的ckpt格式模型转换为saved model的时候出现的问题。

训练时保存为saved model格式

我们常说将TensorFlow模型格式转换为saved model格式,其实在TensorFlow 训练时就可以将其保存为saved model格式。

# 模型输入输出以及模型定义
x = tf.placeholder(dtype = tf.float32, shape = [None,224,224,3], name = 'input_image')
model =VGG16().build_vgg16(x,N_CLASSES)  # 用tf.layers实现的vgg网络
# model输出为softmax概率
'''
模型训练...
'''
tf.saved_model.simple_save(sess,
                    "./saved_models/model_"+str(step),
                        inputs={"MyInput": x},
                        outputs={"MyOutput": model})

#复杂形式
builder = tf.saved_model.builder.SavedModelBuilder("./saved_models1/model_"+str(step))

signature = predict_signature_def(inputs={'myInput': x},
                                  outputs={'myOutput': model})
builder.add_meta_graph_and_variables(sess=sess,
                                     tags=[tf.saved_model.tag_constants.SERVING],
                                     signature_def_map={'predict': signature})
builder.save()

"./saved_models/model_"+str(step) 是模型保存的位置,“MyInput”是保存的模型中x的名称,x是模型的输入,例如输入图像,model是网络的输出,例如softmax分类概率。

将ckpt格式转换为saved model格式

有时候,并没有在训练后将模型保存为saved model格式,重新训练又比较麻烦,需要将模型转换为saved model格式。其实质和训练时保存模型是一样的,需要在会话中恢复网络并加载参数到网络中,然后保存即可。

def restore_and_save(checkpoint_file, export_path):
    '''
    :param: checkpoint_file: 模型保存时的路径+模型名前缀
    :param: export_path: 转换后的模型的路径
    '''
    graph = tf.Graph()

    with graph.as_default():
        session_conf = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)
        sess = tf.Session(config=session_conf)

        with sess.as_default():
            # 载入保存好的meta graph,恢复图中变量,通过SavedModelBuilder保存可部署的模型
            saver = tf.train.import_meta_graph(checkpoint_file+'.meta')
            saver.restore(sess, checkpoint_file)
            print(graph.get_name_scope())
            
            tf.summary.FileWriter(logs_dir, sess.graph)
            
            # 去除多余枝节
            converted_graph_def = tf.graph_util.convert_variables_to_constants(sess, graph.as_graph_def(), ['output'])
            g = tf.graph_util.extract_sub_graph(converted_graph_def, ['output'])
            g = tf.graph_util.remove_training_nodes(g, protected_nodes=["input_x", "output"])
            
            
            
            builder = tf.saved_model.builder.SavedModelBuilder(export_path)
            """
            build_tensor_info
            建立一个基于提供的参数构造的TensorInfo protocol buffer,
            输入:tensorflow graph中的tensor;
            输出:基于提供的参数(tensor)构建的包含TensorInfo的protocol buffer

            get_operation_by_name
            通过name获取checkpoint中保存的变量,能够进行这一步的前提是在模型保存的时候给对应的变量赋予name
            """
            print(graph.get_operation_by_name("input_x").outputs[0].shape)
            input_image = tf.saved_model.utils.build_tensor_info(
                graph.get_operation_by_name("input_x").outputs[0])
            print(graph.get_operation_by_name("output").outputs[0].shape)
            output = tf.saved_model.utils.build_tensor_info(
                graph.get_operation_by_name("output").outputs[0])

            """
            signature_constants
            SavedModel保存和恢复操作的签名常量。
            在序列标注的任务中,这里的method_name是"tensorflow/serving/predict"
            """
            # 定义模型的输入输出,建立调用接口与tensor签名之间的映射
            labeling_signature = (
                tf.saved_model.signature_def_utils.build_signature_def(
                    inputs={
                        "input_image": input_image,
                    },
                    outputs={
                        "output": output
                    },
                    method_name="tensorflow/serving/predict"
                ))

            """
            add_meta_graph_and_variables
            建立一个Saver来保存session中的变量,输出对应的原图的定义,这个函数假设保存的变量已经被初始化;
            对于一个SavedModelBuilder,这个API必须被调用一次来保存meta graph;
            对于后面添加的图结构,可以使用函数 add_meta_graph()来进行添加
            """
            # 建立模型名称与模型签名之间的映射
            builder.add_meta_graph_and_variables(
                sess, [tf.saved_model.tag_constants.SERVING],
                # 保存模型的方法名,与客户端的request.model_spec.signature_name对应
                signature_def_map={
                    tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
                        labeling_signature
                })

            builder.save()
            print("Build Done")

注释

1. 代码中“input_x”是ckpt模型中输入节点的名称,output是输出节点的名称。有时候我们得到一个模型这些信息不一定能完全知道,那么这个模型能够用么?绝对可以。可以通过 代码中的这行: tf.summary.FileWriter(logs_dir, sess.graph)将模型信息保存在logs路径下,通过Tensorboard查看模型信息,模型各个节点名城,shape等。然后再往下进行转换。

2. 代码中有几行是对模型的剪枝,这行看起来不是很重要,但是我在实际部署时,由于同事在训练时,将模型的输出也保存为占位符并保存在模型中,转换模型时没有任何问题,但是当我使用时,我必须要对这个输入进行传参才能够得到输入,服务申请报错是:该占位符没有传参,所以我使用了以上剪枝,然后就没有任何问题了。

 

如果该博客对你有用,点个赞(*—*)

  • 7
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 12
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 12
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值