"""Freeze a TensorFlow 1.x checkpoint into a single GraphDef of constants.

Usage:
    python <this script> --model_dir <checkpoint dir> --output_node_names a,b
"""
import os
import argparse

import tensorflow as tf

# TensorFlow ships its own version of this tool:
# from tensorflow.python.tools.freeze_graph import freeze_graph

# NOTE(review): ``dir`` shadows the builtin ``dir()`` and is not used in this
# script; it is kept only in case another module imports it. Consider
# renaming it (e.g. ``SCRIPT_DIR``) once nothing depends on the old name.
dir = os.path.dirname(os.path.realpath(__file__))


def freeze_graph(model_dir, output_node_names):
    """Extract the sub-graph defined by the output nodes and convert all of
    its variables into constants.

    The most recent checkpoint recorded in ``model_dir`` is restored. Its
    variables are folded into constants, and the resulting GraphDef is written
    to ``frozen_model.pb`` in the same directory as that checkpoint.

    Args:
        model_dir: the root folder containing the checkpoint state file.
        output_node_names: a string containing the names of all output nodes,
            comma separated. Surrounding whitespace and empty entries
            (e.g. from a trailing comma) are ignored.

    Returns:
        The frozen ``GraphDef`` on success, or ``-1`` if no output node name
        was supplied.

    Raises:
        AssertionError: if ``model_dir`` does not exist or contains no
            checkpoint state.
    """
    if not tf.gfile.Exists(model_dir):
        raise AssertionError(
            "Export directory doesn't exists. Please specify an export "
            "directory: %s" % model_dir)

    # Split the comma-separated list. Strip whitespace and drop empty entries
    # so inputs like "a, b" or "a,b," do not produce invalid node names.
    node_names = [name.strip()
                  for name in (output_node_names or "").split(",")
                  if name.strip()]
    if not node_names:
        print("You need to supply the name of a node to --output_node_names.")
        return -1

    # Resolve the full path of the latest checkpoint. get_checkpoint_state
    # returns None when there is no checkpoint state file in model_dir.
    checkpoint = tf.train.get_checkpoint_state(model_dir)
    if checkpoint is None or not checkpoint.model_checkpoint_path:
        raise AssertionError(
            "No checkpoint state found in directory: %s" % model_dir)
    input_checkpoint = checkpoint.model_checkpoint_path

    # Write the frozen graph next to the checkpoint. os.path handles native
    # separators and paths without a directory component. Splitting on '/'
    # turned the latter into "/frozen_model.pb" at the filesystem root.
    output_graph = os.path.join(os.path.dirname(input_checkpoint),
                                "frozen_model.pb")

    # Clear device placements so TensorFlow decides where to load each op.
    clear_devices = True

    # Work in a fresh, empty graph so nothing from the caller's default graph
    # ends up in the frozen output.
    with tf.Session(graph=tf.Graph()) as sess:
        # Import the meta graph into the session's graph.
        saver = tf.train.import_meta_graph(input_checkpoint + '.meta',
                                           clear_devices=clear_devices)
        # Restore the variable values from the checkpoint.
        saver.restore(sess, input_checkpoint)

        # Replace variables with constants holding their restored values,
        # keeping only the nodes needed for the requested outputs.
        output_graph_def = tf.graph_util.convert_variables_to_constants(
            sess,                       # supplies the variable values
            sess.graph.as_graph_def(),  # supplies the graph nodes
            node_names,                 # selects which nodes to keep
        )

        # Serialize the frozen graph and write it to the filesystem.
        with tf.gfile.GFile(output_graph, "wb") as f:
            f.write(output_graph_def.SerializeToString())
        print("%d ops in the final graph." % len(output_graph_def.node))

    return output_graph_def


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_dir", type=str, default="",
                        help="Model folder to export")
    parser.add_argument("--output_node_names", type=str, default="",
                        help="The name of the output nodes, comma separated.")
    args = parser.parse_args()
    result = freeze_graph(args.model_dir, args.output_node_names)
    # freeze_graph signals a usage error by returning -1. Report it through
    # the exit status so calling shell scripts can detect the failure.
    if isinstance(result, int):
        raise SystemExit(1)