python opencv载入TensorFlow训练的pb模型
感谢https://blog.csdn.net/heiheiya/article/details/88825135文章提供的帮助。
先看一下调用模型的代码。
import cv2
import time
# Pretrained classes in the model
classNames = {0: 'background',
1: 'blue', 2: 'hn', 3: 'mf', 4: 'red', 5: 'tls', 6: 'whh',
7: 'wq', 8: 'xh', 9: 'xb'}
def id_class_name(class_id, classes):
for key, value in classes.items():
if class_id == key:
return value
t1=time.time()
# Loading model
model = cv2.dnn.readNetFromTensorflow('frozen_inference_graph.pb', 'graph.pbtxt')
image = cv2.imread("images/yld.jpg")
image_height, image_width, _ = image.shape
model.setInput(cv2.dnn.blobFromImage(image, size=(300, 300), swapRB=True))
output = model.forward()
# print(output[0,0,:,:].shape)
for detection in output[0, 0, :, :]:
confidence = detection[2]
if confidence > .5:
class_id = detection[1]
class_name = id_class_name(class_id, classNames)
print(str(str(class_id) + " " + str(detection[2]) + " " + class_name))
box_x = detection[3] * image_width
box_y = detection[4] * image_height
box_width = detection[5] * image_width
box_height = detection[6] * image_height
print(box_x, box_y, box_width, box_height)
cv2.rectangle(image, (int(box_x), int(box_y)), (int(box_width), int(box_height)), (23, 230, 210), thickness=1)
cv2.putText(image, class_name, (int(box_x), int(box_y + .05 * image_height)), cv2.FONT_HERSHEY_SIMPLEX,
(.005 * image_width), (0, 0, 255))
cv2.imshow('image', image)
# cv2.imwrite("image_box_text.jpg",image)
t2=time.time()
print(t2-t1)
cv2.waitKey(0)
cv2.destroyAllWindows()
之前尝试其他博客的方式,在调用cv2.dnn.readNetFromTensorflow函数时都会有一个错误,解析graph.pbtxt文件失败。原因就是网上提供的pbtxt文件都已经过时了,因此最好的方式便是用官网提供的pb产生pbtxt文件的方式。为了方便起见,这是通往官网的链接:https://github.com/opencv/opencv/tree/master/samples/dnn
里面有两个文件需要下载,tf_text_graph_common.py与tf_text_graph_ssd.py(这是ssd模型转pbtxt的py文件,其他神经网络链接里也有转化代码文件)。
复制粘贴后,运行命令
python tf_text_graph_ssd.py \
--input /path/xxx.pb \
--config /path/xxx.config \
--output /path/xxx.pbtxt
/path可以是绝对路径,也可以是相对路径,自己设置。
再去调用cv2.dnn.readNetFromTensorflow,完美通过。