tensorflow ckpt模型转saved_model格式并进行模型预测

导读

tensorflow的checkpoint模型文件,只包含了模型的参数并不包含模型结构,为了方便使用tensorflow的serving进行部署,我们需要将checkpoint模型转换为saved_model格式

转换代码如下

def ckpt_to_pb(ckpt_path,output_pd_path):
"""
ckpt_path:checkpoint模型文件的目录
output_pd_path:savedmodel模型文件保存的目录
"""
	#加载模型的参数文件
    experiment_folder = "/tmp/"
    config = json.load(open(experiment_folder + 'config.json'))
	#根据的模型参数文件获取模型的结构(输入和输出)
    [x, y_, is_train, y, normalized_y, cost] = train.tf_define_model_and_cost(config)

    graph = tf.Graph()
    with tf.compat.v1.Session(graph=graph) as sess:
		#定义模型的输入输出节点
        SignatureDef = sm.signature_def_utils.build_signature_def(
            inputs={
                "x_input": sm.utils.build_tensor_info(x),
                "is_train": sm.utils.build_tensor_info(is_train)
            },
            outputs={
                "y_sigmoid": sm.utils.build_tensor_info(normalized_y)
            },
            method_name=sm.signature_constants.PREDICT_METHOD_NAME,
        )
		#加载checkpoint模型参数	
        loader = tf.compat.v1.train.import_meta_graph(ckpt_path + ".meta")
        loader.restore(sess,ckpt_path)
		
		#将checkpoint模型转换为savedmodel
        builder = tf.compat.v1.saved_model.builder.SavedModelBuilder(output_pd_path)
        builder.add_meta_graph_and_variables(sess,tags = [tf.compat.v1.saved_model.tag_constants.SERVING],
                                             signature_def_map={sm.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: SignatureDef},
                                             strip_default_attrs=True)

        builder.save()

加载savedmodel模型进行预测

import tensorflow as tf

export_dir = "/save_model"
#加载savedmodel模型
imported = tf.saved_model.load(export_dir)
model = imported.signatures["serving_default"]
#模型预测
pred = model(x_input=tf.convert_to_tensor(input_array), is_train=tf.constant(False))
#获取模型的预测结果
pred = pred["y_sigmoid"].numpy()
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

修炼之路

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值