# 方式一 (Method 1): freeze the Keras .h5 model into a single binary GraphDef (.pb) file.
from tensorflow.python.keras.models import load_model
import tensorflow as tf
from tensorflow.python.keras import backend as K
from tensorflow.python.framework import graph_io
from util.neural_network_util import ManDist
def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
    """Freeze the variables of a TF1 session into constants inside a GraphDef.

    Every global variable of ``session.graph`` whose op name is not listed in
    ``keep_var_names`` is replaced by a constant holding its current value in
    ``session``, so the result can be written out as a single ``.pb`` file.

    Args:
        session: TensorFlow session whose graph and variable values are frozen.
        keep_var_names: Optional iterable of variable op names that must NOT be
            converted to constants.
        output_names: Optional list of node names the frozen graph must keep
            (e.g. the model's output op). The caller's list is not modified.
        clear_devices: If True, clear the device placement of every node so the
            graph is not tied to the device layout it was built on.

    Returns:
        The GraphDef returned by ``convert_variables_to_constants``.
    """
    from tensorflow.python.framework.graph_util import convert_variables_to_constants

    graph = session.graph
    with graph.as_default():
        # Inside `as_default()`, tf.global_variables() reads from `graph`.
        var_op_names = [v.op.name for v in tf.global_variables()]
        freeze_var_names = list(set(var_op_names).difference(keep_var_names or []))
        # Build a new list: `+=` on the argument itself would mutate the
        # caller's list in place.
        output_names = list(output_names or [])
        # Variable ops are listed as outputs too, so they are kept (as
        # constants) instead of being pruned from the extracted sub-graph.
        output_names += var_op_names
        input_graph_def = graph.as_graph_def()
        if clear_devices:
            for node in input_graph_def.node:
                node.device = ""
        frozen_graph = convert_variables_to_constants(
            session, input_graph_def, output_names, freeze_var_names)
        return frozen_graph
# Method 1 driver: load the Keras .h5 model, freeze its session and write the
# frozen graph to `output_path/pb_model_name` as a binary protobuf.
h5_model_path = '/Users/jiaotongyu/Desktop/jiao/i-qdroid/models/SiameseLSTM.h5'
output_path = '.'
# NOTE(review): the output is named 'unet.pb' although the source model is
# SiameseLSTM.h5 -- looks like a copy-paste leftover; confirm before renaming.
pb_model_name = 'unet.pb'
# Switch Keras to inference mode BEFORE the model is loaded, so layers whose
# behaviour depends on the learning phase (e.g. dropout) are built for inference.
K.set_learning_phase(0)
net_model = load_model(h5_model_path, custom_objects={'ManDist': ManDist})
# Freeze the same session that now holds the loaded weights.
sess = K.get_session()
frozen_graph = freeze_session(sess, output_names=[net_model.output.op.name])
graph_io.write_graph(frozen_graph, output_path, pb_model_name, as_text=False)
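# 方式二（可设置 tags）(Method 2, tags can be set): export the model as a TensorFlow SavedModel with a serving signature.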
from util.neural_network_util import ManDist
import os
import tensorflow as tf
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.models import load_model
from tensorflow.python.keras.models import Model
def _tensor_infos(base_key, tensors):
    """Map signature keys to TensorInfo protos for a list of graph tensors.

    A single tensor is registered under ``base_key``. Several tensors are
    registered as ``base_key_0``, ``base_key_1``, ... in list order.
    """
    if len(tensors) == 1:
        return {base_key: tf.saved_model.utils.build_tensor_info(tensors[0])}
    return {
        '%s_%d' % (base_key, index): tf.saved_model.utils.build_tensor_info(tensor)
        for index, tensor in enumerate(tensors)
    }


def export_model(model, export_model_dir, model_version):
    """Export a Keras model as a TensorFlow SavedModel for serving.

    The model is written to ``<export_model_dir>/<model_version>`` with the
    SERVING tag. One predict signature is registered under both ``'predict'``
    and the default serving signature key.

    Args:
        model: Loaded Keras model whose variables live in the Keras backend
            session (``K.get_session()``).
        export_model_dir: Base export directory; created if it does not exist.
        model_version: Version sub-directory name (converted with ``str``);
            this directory should not already contain an export.
    """
    with tf.get_default_graph().as_default():
        # Use model.inputs / model.outputs, which are always lists. Indexing
        # model.input / model.output with [0] slices off the first batch row
        # when the model has a single input/output tensor. For a multi-input
        # model (e.g. a Siamese network) it exposes only the first input.
        prediction_signature = tf.saved_model.signature_def_utils.build_signature_def(
            inputs=_tensor_infos('inputs', model.inputs),
            # NOTE(review): 'outpus' looks like a typo for 'outputs'; kept
            # unchanged because serving clients may already use this key.
            outputs=_tensor_infos('outpus', model.outputs),
            method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME,
        )
        # Create the base directory only when missing, so further versions can
        # be exported next to existing ones (os.mkdir failed if it existed).
        if not os.path.isdir(export_model_dir):
            os.makedirs(export_model_dir)
        export_path = os.path.join(tf.compat.as_bytes(export_model_dir),
                                   tf.compat.as_bytes(str(model_version)))
        builder = tf.saved_model.builder.SavedModelBuilder(export_path)
        builder.add_meta_graph_and_variables(
            sess=K.get_session(),
            tags=[tf.saved_model.tag_constants.SERVING],
            signature_def_map={
                'predict': prediction_signature,
                tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: prediction_signature,
            },
        )
        builder.save()
# Method 2 driver: load SiameseLSTM.h5 (its custom ManDist layer has to be
# supplied through custom_objects) and export it as SavedModel version "1"
# under the relative directory "test_model".
model = load_model(
    "/Users/jiaotongyu/Desktop/jiao/i-qdroid/models/SiameseLSTM.h5",
    custom_objects=dict(ManDist=ManDist),
)
export_model(model, export_model_dir="test_model", model_version="1")