导读
tensorflow的checkpoint
模型文件,只包含了模型的参数并不包含模型结构,为了方便使用tensorflow的serving
进行部署,我们需要将checkpoint模型转换为saved_model
格式
转换代码如下
def ckpt_to_pb(ckpt_path,output_pd_path):
"""
ckpt_path:checkpoint模型文件的目录
output_pd_path:savedmodel模型文件保存的目录
"""
#加载模型的参数文件
experiment_folder = "/tmp/"
config = json.load(open(experiment_folder + 'config.json'))
#根据的模型参数文件获取模型的结构(输入和输出)
[x, y_, is_train, y, normalized_y, cost] = train.tf_define_model_and_cost(config)
graph = tf.Graph()
with tf.compat.v1.Session(graph=graph) as sess:
#定义模型的输入输出节点
SignatureDef = sm.signature_def_utils.build_signature_def(
inputs={
"x_input": sm.utils.build_tensor_info(x),
"is_train": sm.utils.build_tensor_info(is_train)
},
outputs={
"y_sigmoid": sm.utils.build_tensor_info(normalized_y)
},
method_name=sm.signature_constants.PREDICT_METHOD_NAME,
)
#加载checkpoint模型参数
loader = tf.compat.v1.train.import_meta_graph(ckpt_path + ".meta")
loader.restore(sess,ckpt_path)
#将checkpoint模型转换为savedmodel
builder = tf.compat.v1.saved_model.builder.SavedModelBuilder(output_pd_path)
builder.add_meta_graph_and_variables(sess,tags = [tf.compat.v1.saved_model.tag_constants.SERVING],
signature_def_map={sm.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: SignatureDef},
strip_default_attrs=True)
builder.save()
加载savedmodel模型进行预测
import tensorflow as tf
export_dir = "/save_model"
#加载savedmodel模型
imported = tf.saved_model.load(export_dir)
model = imported.signatures["serving_default"]
#模型预测
pred = model(x_input=tf.convert_to_tensor(input_array), is_train=tf.constant(False))
#获取模型的预测结果
pred = pred["y_sigmoid"].numpy()