tensorflow pb格式模型预测规范写法。不然耗时多,或者内存溢出

 背景:加载pb格式模型文件并预测

解决方法:要声明新的图和session并复用,

# 加载 
self.graph = tf.Graph()  # 为每个类(实例)单独创建一个graph
with self.graph.as_default():
     output_graph_def = tf.GraphDef()
     pb_path = wenlp_configs["sentence_matcher"]["pb_model_path"]
     with open(pb_path, "rb") as f:
          output_graph_def.ParseFromString(f.read())
          tf.import_graph_def(output_graph_def, name="")

        
self.sess = tf.Session(graph=self.graph) # 关键代码
with self.sess.as_default():
     self.embeding('你好') # 预测一下这样才可以

# 预测
 with self.sess.as_default():
      input_y = self.graph.get_tensor_by_name("input_y:0")
      qs_y_raw = self.graph.get_tensor_by_name("representation/qs_y_raw:0")
      qs_y_raw_out = self.sess.run(qs_y_raw, feed_dict={input_y: temp})
      vecs = qs_y_raw_out / (qs_y_raw_out ** 2).sum(axis=1, keepdims=True) ** 0.5
      return vecs[0]

 

©️2020 CSDN 皮肤主题: 大白 设计师:CSDN官方博客 返回首页