工业界常使用 TensorFlow Serving（TF Serving）框架部署深度学习模型。为了部署训练好的模型，通常需要把 checkpoint（ckpt）格式的文件转换为 TF Serving 可加载的 SavedModel 格式（导出目录下包含 saved_model.pb 文件和 variables 目录）。具体代码如下：
import os
import tensorflow as tf
# Convert the latest training checkpoint into a SavedModel that TF Serving can load.
# NOTE(review): this uses TensorFlow 1.x APIs (tf.ConfigProto, tf.Session,
# tf.saved_model.builder); under TensorFlow 2 they live under tf.compat.v1.

# Directory holding the training checkpoints (the `checkpoint` index file plus
# the *.meta / *.index / *.data-* files).
# NOTE(review): the original path was "./saved_mode/model_name", which looks like a
# typo for "saved_model" (the export path below uses "saved_model") — confirm.
CHECKPOINT_DIR = "./saved_model/model_name"
# Destination directory for the SavedModel export.
# NOTE(review): TF Serving normally expects a numeric version subdirectory
# (e.g. .../model_name/1), and SavedModelBuilder refuses to write into a directory
# that already exists — confirm this path against the deployment layout.
EXPORT_DIR = "./saved_model/model_name/pb"

checkpoint_file = tf.train.latest_checkpoint(CHECKPOINT_DIR)
if checkpoint_file is None:
    # latest_checkpoint returns None when no checkpoint is found; fail with a clear
    # message instead of trying to import "None.meta".
    raise FileNotFoundError("No checkpoint found in {}".format(CHECKPOINT_DIR))

graph = tf.Graph()
with graph.as_default():
    session_conf = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)
    # Entering the session makes it the default session for this block and
    # closes it (releasing its resources) when the export finishes.
    with tf.Session(config=session_conf) as sess:
        # Rebuild the graph structure from the .meta file, then load the weights.
        saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file))
        saver.restore(sess, checkpoint_file)

        builder = tf.saved_model.builder.SavedModelBuilder(EXPORT_DIR)

        # Tensor names must match the names the tensors were given when the
        # training graph was built.
        input_tensor = tf.saved_model.utils.build_tensor_info(
            graph.get_tensor_by_name("input_tensor:0"))
        output_tensor = tf.saved_model.utils.build_tensor_info(
            graph.get_tensor_by_name("output_tensor:0"))

        # The dictionary keys ("input_tensor:0" / "output_tensor:0") are the names
        # serving clients must use in requests; they are kept unchanged.
        labeling_signature = tf.saved_model.signature_def_utils.build_signature_def(
            inputs={"input_tensor:0": input_tensor},
            outputs={"output_tensor:0": output_tensor},
            method_name="tensorflow/serving/predict",
        )

        # Save the graph plus variables under the SERVING tag, and expose the
        # signature as the default serving signature.
        builder.add_meta_graph_and_variables(
            sess,
            [tf.saved_model.tag_constants.SERVING],
            signature_def_map={
                tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
                    labeling_signature,
            },
        )
        builder.save()

print("Build complete!")