有时我们训练了多个模型,想合并使用它们:如检测模型和分类模型,pose模型和分类模型等,实际应用时模型之前存在着先后的串联关系或者并行关系等。
解决方法:
- 需要建立多个图,然后每个图导入一个模型,再针对每个图创建一个会话
- 用简单的串联合并关系创建新图
- 子图合并为大图(跟上面一种方法类似,但子图间没有任何关联,减少了模型之间的束缚,但也可以让它们存在关系)
这里主要介绍最后一种方法:
with tf.Graph().as_default() as g_combined:
with tf.Session(graph=g_combined) as sess:
graph_def_detect = load_def(detect_pb_path)
graph_def_seg= load_def(seg_pb_path)
input_image = tf.placeholder(dtype=tf.uint8,shape=[1,None,None,3], name="image")#定义新的网络输入
input_image1 = tf.placeholder(dtype=tf.float32,shape=[1,None,None,3], name="image1")
#将原始网络的输入映射到input_image(节点为:新的输入节点“image”)
detection = tf.import_graph_def(graph_def_detect, input_map={'image_tensor:0': input_image},return_elements=['detection_boxes:0', 'detection_scores:0','detection_classes:0','num_detections:0' ])
#新的输出节点为“detect”
tf.identity(detection, 'detect')
# second graph load
seg_predict = tf.import_graph_def(graph_def_seg, input_map={"create_inputs/batch:0": input_image1}, return_elements=["conv6/out_1:0"])
tf.identity(seg_predict, "seg_predict")
# freeze combined graph
g_combined_def = graph_util.convert_variables_to_constants(sess, sess.graph_def, ["seg_predict","detect"])
#合成大图,生成新的pb
tf.train.write_graph(g_combined_def, out_pb_path, 'merge_model.pb', as_text=False)
最后调用新生成的.pb 即可。