代码:
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @ Date : 2022/12/21 13:19
# @ Author : paperClub
#
# Loads a serialized TensorFlow GraphDef (.pb) into a TF1-style session and
# runs a single inference through it.

import os

# Set before importing TensorFlow so its C++ INFO/WARNING log output is hidden.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import tensorflow

# The code below uses the graph/session API. On TensorFlow 2.x and later that
# API lives under `tensorflow.compat.v1` and v2 behaviour has to be disabled.
# The major version is parsed from the full first component of the version
# string rather than from its first character only.
if int(tensorflow.__version__.split('.')[0]) >= 2:
    import tensorflow.compat.v1 as tf
    tf.disable_v2_behavior()
else:
    import tensorflow as tf


def load_model(pb_model_file):
    """Load a serialized GraphDef file and return a session bound to it.

    Args:
        pb_model_file: Path to the binary ``.pb`` file containing a
            serialized ``GraphDef``.

    Returns:
        A ``tf.Session`` whose graph holds the imported nodes. Nodes are
        imported with ``name=""``, so they keep the names stored in the file.

    Raises:
        OSError: If ``pb_model_file`` cannot be opened or read.
    """
    graph = tf.Graph()
    with graph.as_default():
        graph_def = tf.GraphDef()
        # Read through a context manager so the file handle is always closed.
        with open(pb_model_file, 'rb') as pb_file:
            graph_def.ParseFromString(pb_file.read())
        tf.import_graph_def(graph_def, name="")

    session = tf.Session(graph=graph)
    with session.as_default(), graph.as_default():
        # NOTE(review): presumably a no-op for a frozen graph (weights stored
        # as constants); kept so behaviour matches the original script.
        session.run(tf.global_variables_initializer())
    return session


session = None
if session is None:
    pb_model_file = "./tf2_model.pb"
    session = load_model(pb_model_file)

layer_input = 'input_1:0'  # adjust to the input tensor name of your model
# get_tensor_by_name() needs "<op_name>:<output_index>"; a bare op name such
# as 'output' is rejected with ValueError, hence the ':0' suffix.
layer_output = 'output:0'  # adjust to the output tensor name of your model
img_input = ''  # placeholder: replace with the actual input data for the model

feed_input = session.graph.get_tensor_by_name(layer_input)
feches = session.graph.get_tensor_by_name(layer_output)

res = session.run(feches, feed_dict={feed_input: img_input})