训练好的模型想用于视频中物体的识别和跟踪,由于opencv用于视频和图片的处理非常方便,所以想用opencv直接导入tensorflow训练好的模型。
opencv从3.3版本开始正式支持DNN,可以直接导入caffe、tensorflow等框架训练好的模型,进而完成识别、检测等任务。
opencv加载tensorflow训练好的模型,采用readNetFromTensorflow(model,config),第一个参数对应训练好的模型文件frozen_inference_graph.pb的路径,第二个参数对应于一个生成的config文件,它其实是一个protobuf格式的文本的网络结构定义,下文会讲如何生成。加载完之后,使用blobFromImage函数,将图片转换成blob格式,网络接收输入数据后,通过forward()函数进行前向传播,即可得到网络输出的结果,检测视频其实只需要对视频中每一帧进行检测,即可得到对视频的检测结果。
1、官网git工程:
https://github.com/opencv/opencv(注意有3个branch,选择OpenCV的对应版本)
2、如何生成protobuf格式的网络结构定义config文件,
opencv提供了转换脚本,如下:
f_text_graph_ssd.py
tf_text_graph_faster_rcnn.py
tf_text_graph_mask_rcnn.py
首先根据你选取的网络模型,选择对应的脚本,我用的是ssd_mobilenet_v2的。这个脚本有三个参数,第一个是你训练好的frozen_inference_graph.pb的路径,第二个是训练时使用的pipeline_config文件的路径,第三个就是config文件的输出路径了,如下:
python tf_text_graph_ssd.py --input /path/to/model.pb --config /path/to/example.config --output /path/to/graph.pbtxt
有了graph.pbtxt这个文件,我们就可以用opencv的readNetFromTensorflow导入训练好的模型了,新建python文件,具体如下:
import cv2 as cv
cvNet = cv.dnn.readNetFromTensorflow('frozen_inference_graph.pb', 'graph.pbtxt')
img = cv.imread('example.jpg')
rows = img.shape[0]
cols = img.shape[1]
cvNet.setInput(cv.dnn.blobFromImage(img, size=(300, 300), swapRB=True, crop=False))
cvOut = cvNet.forward()
for detection in cvOut[0,0,:,:]:
score = float(detection[2])
if score > 0.3:
left = detection[3] * cols
top = detection[4] * rows
right = detection[5] * cols
bottom = detection[6] * rows
cv.rectangle(img, (int(left), int(top)), (int(right), int(bottom)), (23, 230, 210), thickness=2)
cv.putText(img, str(score), (int(right), int(bottom)), cv.FONT_HERSHEY_SIMPLEX, 1, (23, 230, 210), 2)
cv.imshow('img', img)
cv.waitKey()
运行时报错如下:
这是因为OpenCV的版本可能低了,我的是3.4.1.15,之后更新到了3.4.5.20版本(参考:https://github.com/opencv/opencv/issues/12518)
直接调用mobilenet_ssd v1 的权重和 生成的graph.pbtxt文件如下表:
(参考:https://github.com/opencv/opencv/wiki/TensorFlow-Object-Detection-API)
调用自己训练的模型:
先生成graph.pbtxt
命令:
$ python tf_text_graph_ssd.py --input /path/to/model.pb --config /path/to/example.config --output /path/to/graph.pbtxt
运行结果:
参考博客:
https://blog.csdn.net/eddyli/article/details/85064952
https://github.com/opencv/opencv/wiki/TensorFlow-Object-Detection-API