import tensorflow as tf
from tensorflow.python.tools.freeze_graph import freeze_graph
def freeze(ckpt_dir, meta_file_path, output_node_names,
           output_graph='frozen_graph.pb'):
    """Freeze the latest checkpoint in *ckpt_dir* into a single GraphDef file.

    The graph structure is loaded from *meta_file_path* and the weights from
    the most recent checkpoint in *ckpt_dir*. ``freeze_graph`` then turns the
    variables needed by *output_node_names* into constants and writes the
    result to *output_graph*.

    Args:
        ckpt_dir: Directory searched with ``tf.train.latest_checkpoint``.
        meta_file_path: Path to the ``.meta`` file holding the MetaGraphDef.
        output_node_names: Nodes to keep. Either a comma-separated string
            (e.g. ``'input,classifier/Softmax'``) or an iterable of names.
        output_graph: Destination path of the frozen graph. Defaults to
            ``'frozen_graph.pb'`` (relative to the working directory).

    Raises:
        ValueError: If *ckpt_dir* contains no checkpoint.
    """
    import os
    import tempfile

    # freeze_graph expects a comma-separated string; also accept a list/tuple.
    if not isinstance(output_node_names, str):
        output_node_names = ','.join(output_node_names)

    # Resolve the checkpoint once. latest_checkpoint returns None when the
    # directory holds no checkpoint, which would otherwise only surface as an
    # obscure error from saver.restore.
    checkpoint_path = tf.train.latest_checkpoint(ckpt_dir)
    if checkpoint_path is None:
        raise ValueError('No checkpoint found in %r' % (ckpt_dir,))

    # Scratch directory for the intermediate GraphDef; removed on exit instead
    # of leaving temp/temp.pb behind in the working directory.
    with tempfile.TemporaryDirectory() as tmp_dir:
        # Import the structure and load the weights into a private graph, so
        # repeated calls do not accumulate nodes in the global default graph.
        # clear_devices=True lets a graph trained on GPU be restored on a
        # CPU-only machine (freeze_graph below clears devices as well).
        with tf.Graph().as_default() as graph:
            saver = tf.train.import_meta_graph(meta_file_path,
                                               clear_devices=True)
            with tf.Session(graph=graph) as sess:
                saver.restore(sess, checkpoint_path)
                # Written in text format (write_graph's default), hence
                # input_binary=False below.
                tf.train.write_graph(sess.graph_def, tmp_dir, 'temp.pb')

        # freeze_graph imports the GraphDef into the current default graph,
        # so hand it a fresh one rather than a graph that already holds the
        # model.
        with tf.Graph().as_default():
            freeze_graph(input_graph=os.path.join(tmp_dir, 'temp.pb'),
                         input_checkpoint=checkpoint_path,
                         output_graph=output_graph,
                         output_node_names=output_node_names,
                         # Boilerplate settings for freeze_graph:
                         clear_devices=True,
                         input_binary=False,
                         input_saver='',
                         restore_op_name='save/restore_all',
                         filename_tensor_name='save/Const:0',
                         initializer_nodes='')
if __name__ == '__main__':
    # Freeze the model saved under ./ckpt, keeping the input placeholder and
    # both output heads (classification softmax and regression bias-add).
    freeze(
        ckpt_dir='./ckpt',
        meta_file_path='./ckpt/model.meta',
        output_node_names='input,classifier/Softmax,regression/BiasAdd',
    )
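# TensorFlow: freezing a checkpoint into a frozen graph
# (original title: "TensorFlow 将 checkpoint 冻结为 frozen_graph")
# Scraped page residue: "最新推荐文章于 2024-05-19 15:12:42 发布"
# ("latest recommended article published 2024-05-19 15:12:42")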