Python 3:将 Keras 的 .h5 模型转换为 TensorFlow .pb 模型

方式一:冻结计算图(变量转为常量)并导出为单个 .pb 文件

from tensorflow.python.keras.models import load_model  # from keras.models import load_model
import tensorflow as tf
from tensorflow.python.keras import backend as K  # from keras import backend as K
from tensorflow.python.framework import graph_io

from util.neural_network_util import ManDist


def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
    """Freeze the graph of ``session`` into a ``GraphDef`` whose variables are constants.

    Every global variable of the session's graph is converted to a constant
    holding its current value, except the ones named in ``keep_var_names``.

    Args:
        session: TensorFlow session whose graph and variable values are frozen.
        keep_var_names: Optional iterable of variable op names that must NOT be
            converted to constants.
        output_names: Optional sequence of op names the frozen graph must keep
            (typically the model's output ops). The caller's object is not
            modified.
        clear_devices: If True, clear the device string of every node so the
            exported graph is not pinned to the devices it was built on.

    Returns:
        The frozen ``GraphDef`` produced by ``convert_variables_to_constants``.
    """
    from tensorflow.python.framework.graph_util import convert_variables_to_constants
    graph = session.graph
    with graph.as_default():
        # tf.global_variables() reads the default graph, which the `with`
        # statement above has set to session.graph. Collect the names once.
        var_op_names = [v.op.name for v in tf.global_variables()]
        freeze_var_names = list(set(var_op_names).difference(keep_var_names or []))
        # Build a new list rather than using `+=`, so a caller-supplied list is
        # not mutated (and a tuple argument does not raise). Variable ops are
        # added as outputs so they survive the sub-graph extraction performed
        # during conversion.
        output_names = list(output_names or []) + var_op_names
        input_graph_def = graph.as_graph_def()
        if clear_devices:
            for node in input_graph_def.node:
                node.device = ""
        frozen_graph = convert_variables_to_constants(session, input_graph_def,
                                                      output_names, freeze_var_names)
        return frozen_graph



h5_model_path = '/Users/jiaotongyu/Desktop/jiao/i-qdroid/models/SiameseLSTM.h5'
output_path = '.'
# NOTE(review): this file name looks like a leftover from a U-Net example; it is
# kept unchanged, but consider renaming it to match the exported model.
pb_model_name = 'unet.pb'


# Set the Keras learning phase to 0 (inference) before the model is built, so
# layers whose behaviour depends on it are frozen in prediction mode.
K.set_learning_phase(0)
# ManDist is a custom object referenced by the saved model and must be
# registered for deserialization.
net_model = load_model(h5_model_path, custom_objects={'ManDist': ManDist})

# To inspect tensor names, print net_model.inputs / net_model.outputs here.

sess = K.get_session()
# Use every model output (a list), which also covers multi-output models where
# net_model.output would be a list without an `.op` attribute.
output_op_names = [out.op.name for out in net_model.outputs]
frozen_graph = freeze_session(sess, output_names=output_op_names)
graph_io.write_graph(frozen_graph, output_path, pb_model_name, as_text=False)

方式二:导出为 SavedModel 格式(可设置 tags 和签名;生成的版本目录中包含 saved_model.pb 及 variables 子目录,而不是单个冻结的 .pb 文件)

from util.neural_network_util import ManDist
import os
import tensorflow as tf
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.models import load_model
from tensorflow.python.keras.models import Model


def _tensor_info_map(tensors, prefix):
    """Map signature keys to ``TensorInfo`` protos built from ``tensors``.

    A single tensor is published under ``prefix`` itself; several tensors are
    published as ``prefix_0``, ``prefix_1``, ... in the order given.
    """
    infos = [tf.saved_model.utils.build_tensor_info(t) for t in tensors]
    if len(infos) == 1:
        return {prefix: infos[0]}
    return {'{}_{}'.format(prefix, i): info for i, info in enumerate(infos)}


def export_model(model, export_model_dir, model_version):
    """Export a Keras model as a TensorFlow SavedModel for serving.

    The model is written to ``<export_model_dir>/<model_version>`` with the
    SERVING tag. A predict signature is registered both under ``'predict'`` and
    under the default serving signature key.

    Signature keys: a model with one input/output uses ``'inputs'`` /
    ``'outputs'``. A model with several uses ``'inputs_0'``, ``'inputs_1'``, ...
    (likewise for outputs), in the order of ``model.inputs`` / ``model.outputs``.

    Args:
        model: Loaded Keras model whose graph lives in the current Keras session.
        export_model_dir: Parent directory for versioned exports; created
            (with any missing parents) if it does not exist.
        model_version: Name of the version sub-directory (converted with ``str``).

    NOTE(review): SavedModelBuilder is expected to refuse a version directory
    that already exists — confirm for the TF version in use.
    """
    sess = K.get_session()
    with sess.graph.as_default():
        # Use the model.inputs / model.outputs lists instead of indexing
        # model.input / model.output. Indexing a single tensor slices off its
        # first batch row, and indexing a list drops the remaining inputs.
        prediction_signature = tf.saved_model.signature_def_utils.build_signature_def(
            inputs=_tensor_info_map(model.inputs, 'inputs'),
            outputs=_tensor_info_map(model.outputs, 'outputs'),
            method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME)

        # exist_ok=True allows exporting a new version next to existing ones;
        # the version directory itself is handed to the builder.
        os.makedirs(export_model_dir, exist_ok=True)
        export_path = os.path.join(export_model_dir, str(model_version))
        builder = tf.saved_model.builder.SavedModelBuilder(export_path)
        builder.add_meta_graph_and_variables(
            sess=sess,
            tags=[tf.saved_model.tag_constants.SERVING],
            signature_def_map={
                'predict': prediction_signature,
                tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: prediction_signature,
            },
        )
        builder.save()


# Set the Keras learning phase to 0 (inference) before the model is built, so
# the exported graph uses prediction-mode behaviour for phase-dependent layers.
K.set_learning_phase(0)
model = load_model("/Users/jiaotongyu/Desktop/jiao/i-qdroid/models/SiameseLSTM.h5", custom_objects={'ManDist': ManDist})
export_model(model, "test_model", "1")
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值