TensorFlow: Importing a .pb File
Example code
import os

import tensorflow as tf
from tensorflow.python.platform import gfile


def create_model_graph(model_info):
    """Creates a graph from a saved GraphDef file and returns a Graph object.

    Args:
      model_info: Dictionary containing information about the model architecture.

    Returns:
      Graph holding the trained Inception network, and various tensors we'll be
      manipulating.
    """
    with tf.Graph().as_default() as graph:
        # FLAGS is assumed to be defined elsewhere in the script (command-line flags).
        model_path = os.path.join(FLAGS.model_dir, model_info['model_file_name'])
        with gfile.FastGFile(model_path, 'rb') as f:
            # Parse the serialized GraphDef protobuf stored in the .pb file.
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
            # Import the graph and fetch the two tensors by name.
            bottleneck_tensor, resized_input_tensor = (tf.import_graph_def(
                graph_def,
                name='',
                return_elements=[
                    model_info['bottleneck_tensor_name'],
                    model_info['resized_input_tensor_name'],
                ]))
    return graph, bottleneck_tensor, resized_input_tensor
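The function above expects model_info to carry the names of the tensors to fetch. When those names are not known in advance, one way to find them is to list the nodes of the GraphDef. The following is a minimal sketch under stated assumptions: it builds a two-node demo graph in place of a real model, the names resized_input and bottleneck are hypothetical, and it uses tf.compat.v1 so that it also runs on TensorFlow 2.x.

```python
import tensorflow as tf

tf1 = tf.compat.v1  # TF 1.x-style graph API, also exposed in TF 2.x

# Build a tiny demo graph (a stand-in for a real model; names are hypothetical).
with tf.Graph().as_default() as demo_graph:
    images = tf1.placeholder(tf.float32, shape=[None, 4], name='resized_input')
    tf.identity(images, name='bottleneck')

graph_def = demo_graph.as_graph_def()

# Each entry in graph_def.node is one operation; its first output tensor
# is addressed as '<op name>:0'.
node_names = [node.name for node in graph_def.node]
tensor_names = [name + ':0' for name in node_names]

print(node_names)    # ['resized_input', 'bottleneck']
print(tensor_names)  # ['resized_input:0', 'bottleneck:0']
```

Names in the ':0' form are what return_elements takes when a Tensor, rather than an Operation, is wanted back.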
What the related APIs do
gfile.FastGFile:
Google's file I/O wrapper; it works much like Python's built-in open function.
import_graph_def(graph_def, input_map=None, return_elements=None, name=None, op_dict=None, producer_op_list=None):
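Purpose: imports the TensorFlow graph model defined in the graph_def argument. In the words of the docstring: "Imports the TensorFlow graph in graph_def into the Python Graph."
Return value: the Operation and Tensor objects requested through the return_elements argument. In the words of the docstring: "A list of Operation and/or Tensor objects from the imported graph, corresponding to the names in return_elements."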
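To see how return_elements maps names to returned objects, the whole round trip can be sketched on a tiny graph. This is an illustrative sketch, not the original author's code: the file name demo_graph.pb and the tensor names input and doubled are made up for the example, tf.io.gfile.GFile stands in for gfile.FastGFile (which later TensorFlow releases mark as deprecated), and tf.compat.v1 is used so the snippet also runs on TensorFlow 2.x.

```python
import os
import tempfile

import tensorflow as tf

tf1 = tf.compat.v1  # TF 1.x-style graph API, also exposed in TF 2.x

# Build a tiny graph and write it to disk as a stand-in for a real .pb model.
with tf.Graph().as_default() as source_graph:
    x = tf1.placeholder(tf.float32, shape=[None], name='input')
    tf.multiply(x, 2.0, name='doubled')

pb_path = os.path.join(tempfile.mkdtemp(), 'demo_graph.pb')  # hypothetical name
with tf.io.gfile.GFile(pb_path, 'wb') as f:
    f.write(source_graph.as_graph_def().SerializeToString())

# Read the bytes back and parse them into a GraphDef.
with tf.io.gfile.GFile(pb_path, 'rb') as f:
    graph_def = tf1.GraphDef()
    graph_def.ParseFromString(f.read())

# Import into a fresh graph; the returned list follows the order of
# the names passed in return_elements.
with tf.Graph().as_default() as graph:
    input_tensor, output_tensor = tf.import_graph_def(
        graph_def, name='', return_elements=['input:0', 'doubled:0'])

# Run the imported graph to confirm the tensors are wired up.
with tf1.Session(graph=graph) as sess:
    result = sess.run(output_tensor, feed_dict={input_tensor: [1.0, 2.0, 3.0]})

print(result.tolist())
```

Passing name='' keeps the imported node names unprefixed, so the returned tensor is still called doubled:0; with the default name, imported nodes are placed under an import/ prefix.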