使用多个 pb 模型时,有时会遇到不同模型中存在同名节点的情况,导致图加载冲突、无法输出正确结果。由于修改节点名称较为麻烦,因此为每个模型创建独立的新图来解决。
下面的写法直接使用默认图 tf.get_default_graph(),仅适用于各模型之间没有节点名称冲突的情况。
def __init__(self, graph_path, target_size=(300, 300), confidence_thresh=0.9,
             gpu_memory_fraction=0.03, input_name="input_1:0",
             output_name="output_1:0"):
    """Load a frozen ``.pb`` model into the current default graph and open a session.

    This variant imports the model into ``tf.get_default_graph()``. It is only
    safe when no other model imported into that same graph uses the same node
    names; otherwise the tensor lookups below may resolve to the wrong
    model's nodes. Use a dedicated ``tf.Graph()`` per model in that case.

    Args:
        graph_path: Path to a serialized ``GraphDef`` (frozen ``.pb`` file).
        target_size: Stored as ``self.target_size``; not used here
            (presumably the model's input image size -- confirm with caller).
        confidence_thresh: Stored as ``self.confidence_threshold``; not used here.
        gpu_memory_fraction: Value for ``per_process_gpu_memory_fraction`` of
            the session's GPU options (default 0.03, as before).
        input_name: Name of the input tensor to look up (default ``"input_1:0"``).
        output_name: Name of the output tensor to look up (default ``"output_1:0"``).

    Raises:
        TensorFlow raises if ``graph_path`` cannot be read or parsed, or if
        ``input_name`` / ``output_name`` do not exist in the graph.
    """
    self.target_size = target_size
    self.confidence_threshold = confidence_thresh

    # Parse the serialized GraphDef. GFile replaces the deprecated FastGFile.
    graph_def = tf.GraphDef()
    with tf.gfile.GFile(graph_path, 'rb') as f:
        graph_def.ParseFromString(f.read())

    # name='' keeps the node names exactly as stored in the .pb (no "import/"
    # prefix), so the tensor names below can be used verbatim.
    tf.import_graph_def(graph_def, name='')
    self.graph = tf.get_default_graph()

    # Cap the share of GPU memory this process's session may allocate.
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=gpu_memory_fraction)
    self.fire_Detecction_sess = tf.Session(
        graph=self.graph, config=tf.ConfigProto(gpu_options=gpu_options))

    # The session was built on self.graph, so look the tensors up there.
    self.input = self.graph.get_tensor_by_name(input_name)
    self.output = self.graph.get_tensor_by_name(output_name)
当模型之间存在同名节点冲突时,可以用 tf.Graph() 为当前模型创建一个独立的新图,并将模型导入该图中来解决问题。
def __init__(self, graph_path, target_size=(300, 300), confidence_thresh=0.9,
             gpu_memory_fraction=0.03, input_name="input_1:0",
             output_name="output_1:0"):
    """Load a frozen ``.pb`` model into its own ``tf.Graph`` and open a session.

    Importing into a fresh graph (instead of the process-wide default graph)
    isolates this model's node names from those of any other loaded model,
    so identically named nodes in different models do not collide.

    Args:
        graph_path: Path to a serialized ``GraphDef`` (frozen ``.pb`` file).
        target_size: Stored as ``self.target_size``; not used here
            (presumably the model's input image size -- confirm with caller).
        confidence_thresh: Stored as ``self.confidence_threshold``; not used here.
        gpu_memory_fraction: Value for ``per_process_gpu_memory_fraction`` of
            the session's GPU options (default 0.03, as before).
        input_name: Name of the input tensor to look up (default ``"input_1:0"``).
        output_name: Name of the output tensor to look up (default ``"output_1:0"``).

    Raises:
        TensorFlow raises if ``graph_path`` cannot be read or parsed, or if
        ``input_name`` / ``output_name`` do not exist in the graph.
    """
    self.target_size = target_size
    self.confidence_threshold = confidence_thresh

    # A dedicated graph for this model only.
    self.graph = tf.Graph()
    self.graph_def = tf.GraphDef()
    # GFile replaces the deprecated FastGFile.
    with tf.gfile.GFile(graph_path, 'rb') as f:
        self.graph_def.ParseFromString(f.read())

    # Import inside as_default() so the nodes land in self.graph, not in the
    # global default graph. name='' keeps the original node names.
    with self.graph.as_default():
        tf.import_graph_def(self.graph_def, name='')

    # Cap the share of GPU memory this process's session may allocate.
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=gpu_memory_fraction)
    self.fire_Detecction_sess = tf.Session(
        graph=self.graph, config=tf.ConfigProto(gpu_options=gpu_options))

    # Look the tensors up in this model's own graph.
    self.input = self.graph.get_tensor_by_name(input_name)
    self.output = self.graph.get_tensor_by_name(output_name)