定义变量时，顺便打印模型输入和输出张量的 name；导出 SavedModel 时需要用这些张量名来构造签名。
import tensorflow as tf
import time
def convert_model(checkpoint_prefix='/pvc/model/deepfm/model.ckpt-1122',
                  export_dir=None, num_inputs=34):
    """Convert a TF1-style checkpoint into a SavedModel with a 'predict' signature.

    The graph structure is re-imported from ``<checkpoint_prefix>.meta``,
    variable values are restored from the checkpoint, and the session is
    exported under the ``'serve'`` tag. The signature exposes the outputs
    of the graph's ``IteratorGetNext`` op as inputs and ``Sigmoid:0`` as the
    output.

    Args:
        checkpoint_prefix: Checkpoint path prefix (without ``.meta`` /
            ``.index`` suffix).
        export_dir: Directory to write the SavedModel to. Defaults to
            ``/pvc/model/deepfm_pb_<unix timestamp>`` so repeated runs do not
            collide (SavedModelBuilder refuses an existing directory).
        num_inputs: Number of ``IteratorGetNext:<i>`` tensors to expose as
            signature inputs. ``None`` exposes every output of the
            ``IteratorGetNext`` op.

    Returns:
        The directory the SavedModel was written to.

    Raises:
        ValueError: If the meta graph contains no variables to restore.
    """
    if export_dir is None:
        export_dir = '/pvc/model/deepfm_pb_{}'.format(int(time.time()))
    graph = tf.Graph()
    # log_device_placement logs the device of every op; kept from the original
    # script, but it is very verbose on large graphs.
    config = tf.compat.v1.ConfigProto(allow_soft_placement=True,
                                      log_device_placement=True)
    # Use tf.compat.v1 throughout: tf.train.import_meta_graph and
    # tf.saved_model.builder / predict_signature_def are absent in TF 2.x.
    # Entering the session also makes `graph` the default graph, so the meta
    # graph below is imported into it.
    with tf.compat.v1.Session(graph=graph, config=config) as sess:
        # Rebuild the graph structure, then load the variable values.
        saver = tf.compat.v1.train.import_meta_graph(checkpoint_prefix + '.meta')
        if saver is None:
            # import_meta_graph returns None when there is nothing to restore.
            raise ValueError(
                'No variables to restore in {}.meta'.format(checkpoint_prefix))
        saver.restore(sess, checkpoint_prefix)

        if num_inputs is None:
            num_inputs = len(graph.get_operation_by_name('IteratorGetNext').outputs)
        # The training pipeline's iterator outputs serve as the model inputs;
        # at serving time they are fed directly instead of read from the iterator.
        inputs = {
            'IteratorGetNext:{}'.format(i):
                graph.get_tensor_by_name('IteratorGetNext:{}'.format(i))
            for i in range(num_inputs)
        }
        outputs = {'Sigmoid:0': graph.get_tensor_by_name('Sigmoid:0')}

        # Input/output signature of the exported model.
        signature = tf.compat.v1.saved_model.predict_signature_def(
            inputs=inputs, outputs=outputs)
        # Export the restored session as a SavedModel. Clients must request
        # the signature by the name 'predict'.
        builder = tf.compat.v1.saved_model.Builder(export_dir)
        builder.add_meta_graph_and_variables(
            sess,
            ['serve'],
            strip_default_attrs=True,
            signature_def_map={'predict': signature})
        builder.save()
    return export_dir
def _main():
    """Script entry point: run the checkpoint-to-SavedModel conversion."""
    convert_model()


if __name__ == '__main__':
    # Only convert when executed as a script, never on import.
    _main()
参考
tensorflow-ckpt转savemode记录
https://zhuanlan.zhihu.com/p/113734249