见下方代码;注意需在 TensorFlow 2.0 环境下运行。
import tensorflow as tf
import numpy as np
import cv2
def load_pb(pb_path, img_array, input_tensor_name, out_tensor_name):
"""
通过加载pb格式的模型来预测
input_tensor_name like "conv2d_1_input:0"
"""
with tf.Graph().as_default():
output_graph_def = tf.GraphDef()
with open(pb_path, "rb") as f:
# model_bytes = parse_model(f.read())
init = tf.global_variables_initializer()
output_graph_def.ParseFromString(f.read(