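The dnn module in OpenCV can load models from Caffe, TensorFlow, Darknet, ONNX, and other frameworks. PyTorch is very popular at the moment, but because it builds dynamic graphs, deployment is less convenient than with a static graph.
A PyTorch model can be converted to ONNX, but in my experience many interface compatibility problems remain. Loading a TensorFlow model with cv2.dnn.readNetFromTensorflow
is one of the better options for now. This post records how to load a pretrained TF 1.x model through the cv2.dnn.readNetFromTensorflow interface and run inference (TF 2.x models do not seem to be supported yet). The model can be downloaded from: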
http://download.tensorflow.org/models/object_detection/mask_rcnn_inception_v2_coco_2018_01_28.tar.gz
- Extract the TF 1.x Mask R-CNN pretrained model
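The extraction step can also be scripted. The following is a minimal sketch using only the Python standard library; the helper name `extract_model` and the destination folder are illustrative assumptions, while the file name `frozen_inference_graph.pb` is the one used later in this post.

```python
import tarfile
from pathlib import Path


def extract_model(archive_path, dest_dir):
    """Extract a .tar.gz model archive into dest_dir.

    Returns the sorted paths of every frozen_inference_graph.pb found
    below dest_dir (hypothetical helper, not part of OpenCV).
    """
    dest = Path(dest_dir)
    dest.mkdir(parents=True, exist_ok=True)
    # Only extract archives from a source you trust.
    with tarfile.open(archive_path, "r:gz") as tar:
        tar.extractall(dest)
    return sorted(dest.rglob("frozen_inference_graph.pb"))


# Example call (paths are placeholders):
# extract_model("mask_rcnn_inception_v2_coco_2018_01_28.tar.gz", "models")
```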
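- Generate the .pbtxt description file
OpenCV DNN needs a .pbtxt model description file in order to parse a TensorFlow .pb model file and load the network. The OpenCV source tree ships Python scripts that generate the .pbtxt file; they are located under the OpenCV installation path, in the folder
\opencv\sources\samples\dnn
Pick the script that matches the type of model.
Run the following command: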
python tf_text_graph_mask_rcnn.py --input "D:\GoogleDownload\mask_rcnn_inception_v2_coco_2018_01_28\frozen_inference_graph.pb" --output "D:\GoogleDownload\mask_rcnn_inception_v2_coco_2018_01_28\mask_rcnn_inceptionv2_coco.pbtxt" --config "D:\GoogleDownload\mask_rcnn_inception_v2_coco_2018_01_28\pipeline.config"
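Because the three paths in the command share one folder, the argument list can be built from the model directory. This is only a sketch: the helper `build_pbtxt_command` and its default output name are assumptions for illustration, while the script name and the `--input`, `--output`, and `--config` flags are taken from the command above.

```python
import sys
from pathlib import Path


def build_pbtxt_command(model_dir,
                        script="tf_text_graph_mask_rcnn.py",
                        output_name="mask_rcnn_inceptionv2_coco.pbtxt"):
    """Return the argument list for the .pbtxt generation script.

    Hypothetical helper: it only assembles the command, it does not run it.
    """
    model_dir = Path(model_dir)
    return [
        sys.executable, script,
        "--input", str(model_dir / "frozen_inference_graph.pb"),
        "--output", str(model_dir / output_name),
        "--config", str(model_dir / "pipeline.config"),
    ]


# To run it from the samples\dnn folder:
# import subprocess
# subprocess.run(build_pbtxt_command(r"D:\path\to\model_dir"), check=True)
```

Passing the arguments as a list avoids quoting problems with Windows paths that contain spaces.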
- Model inference
I use the dnn.readNet() interface here. I originally wanted to use the cv2.dnn_SegmentationModel interface, the same way as for object detection, but it raised an error whose cause I have not found yet, so the implementation below follows the official mask_rcnn.py example.

# -*- coding: utf-8 -*-
"""
Created on Fri Aug 14 22:21:53 2020

@author: 周文青

Load a Mask R-CNN model with the OpenCV dnn module
"""
import cv2 as cv
import matplotlib.pyplot as plt
import numpy as np

coco_names = r"F:\opencv\sources\samples\dnn\coco.names"
args = {
    'model': r"D:\GoogleDownload\mask_rcnn_inception_v2_coco_2018_01_28\frozen_inference_graph.pb",
    'config': r"D:\GoogleDownload\mask_rcnn_inception_v2_coco_2018_01_28\mask_rcnn_inceptionv2_coco.pbtxt",
    'width': 800, 'height': 800, 'thr': 0.5}
img_file = r"C:\Users\admin\Pictures\car.jpg"

# Read the class names, one per line.
with open(coco_names, 'rt') as f:
    classes = f.read().rstrip('\n').split('\n')

frame = cv.imread(img_file)
if frame is None:
    raise FileNotFoundError(img_file)

# Load a network
net = cv.dnn.readNet(cv.samples.findFile(args['model']), cv.samples.findFile(args['config']))
net.setPreferableBackend(cv.dnn.DNN_BACKEND_OPENCV)

winName = 'Mask-RCNN in OpenCV'
# cv.namedWindow(winName, cv.WINDOW_NORMAL)

frameH = frame.shape[0]
frameW = frame.shape[1]

# Create a 4D blob from a frame.
blob = cv.dnn.blobFromImage(frame, size=(args['width'], args['height']), swapRB=True, crop=False)

# Run a model
net.setInput(blob)
boxes, masks = net.forward(['detection_out_final', 'detection_masks'])

numClasses = masks.shape[1]
numDetections = boxes.shape[2]

# Load colors (set `colors` to a file path to read them from a file)
colors = None
if colors:
    with open(colors, 'rt') as f:
        colors = [np.array(color.split(' '), np.uint8) for color in f.read().rstrip('\n').split('\n')]

# Draw segmentation
if not colors:
    # Generate colors
    colors = [np.array([0, 0, 0], np.uint8)]
    for i in range(1, numClasses + 1):
        colors.append((colors[i - 1] + np.random.randint(0, 256, [3], np.uint8)) / 2)
    del colors[0]


def drawBox(frame, classId, conf, left, top, right, bottom):
    # Draw a bounding box.
    cv.rectangle(frame, (left, top), (right, bottom), (0, 255, 0))

    label = '%.2f' % conf

    # Print a label of class.
    if classes:
        assert(classId < len(classes))
        label = '%s: %s' % (classes[classId], label)

    labelSize, baseLine = cv.getTextSize(label, cv.FONT_HERSHEY_SIMPLEX, 0.5, 1)
    top = max(top, labelSize[1])
    cv.rectangle(frame, (left, top - labelSize[1]), (left + labelSize[0], top + baseLine), (255, 255, 255), cv.FILLED)
    cv.putText(frame, label, (left, top), cv.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0))


boxesToDraw = []
for i in range(numDetections):
    box = boxes[0, 0, i]
    mask = masks[i]
    score = box[2]
    if score > args['thr']:
        classId = int(box[1])

        # Box coordinates are normalized; scale them to the frame and clamp.
        left = int(frameW * box[3])
        top = int(frameH * box[4])
        right = int(frameW * box[5])
        bottom = int(frameH * box[6])

        left = max(0, min(left, frameW - 1))
        top = max(0, min(top, frameH - 1))
        right = max(0, min(right, frameW - 1))
        bottom = max(0, min(bottom, frameH - 1))

        boxesToDraw.append([frame, classId, score, left, top, right, bottom])

        # Resize the class mask to the box size and blend the class color into the frame.
        classMask = mask[classId]
        classMask = cv.resize(classMask, (right - left + 1, bottom - top + 1))
        mask = (classMask > 0.5)

        roi = frame[top:bottom+1, left:right+1][mask]
        frame[top:bottom+1, left:right+1][mask] = (0.7 * colors[classId] + 0.3 * roi).astype(np.uint8)

for box in boxesToDraw:
    drawBox(*box)

# Put efficiency information.
t, _ = net.getPerfProfile()
label = 'Inference time: %.2f ms' % (t * 1000.0 / cv.getTickFrequency())
cv.putText(frame, label, (0, 15), cv.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 0))

# Show the result with matplotlib (BGR -> RGB).
plt.imshow(frame[:, :, ::-1])
plt.xticks([])
plt.yticks([])
plt.box(False)
plt.tight_layout(pad=0.0)
plt.show()
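The coordinate handling in the detection loop can be read in isolation. The sketch below repeats that step in plain Python, assuming (as the script does) that each detection row stores the class id at index 1, the score at index 2, and the normalized left, top, right, bottom values at indices 3 to 6; the function name `to_pixel_box` is hypothetical.

```python
def to_pixel_box(det, frame_w, frame_h):
    """Convert one 7-element detection row to a clamped pixel box.

    Returns (left, top, right, bottom), each limited to the frame.
    """
    def clamp(value, upper):
        # Truncate to int, then keep the result inside [0, upper].
        return max(0, min(int(value), upper))

    left = clamp(frame_w * det[3], frame_w - 1)
    top = clamp(frame_h * det[4], frame_h - 1)
    right = clamp(frame_w * det[5], frame_w - 1)
    bottom = clamp(frame_h * det[6], frame_h - 1)
    return left, top, right, bottom


# A box that extends past the bottom edge of a 200x100 frame is clamped.
print(to_pixel_box([0, 2, 0.9, 0.25, 0.5, 0.75, 1.5], 200, 100))
```

Clamping matters because the mask is resized to `(right - left + 1, bottom - top + 1)` and written into that slice of the frame, so the box has to lie inside the image.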
The detection result is:
Lowering the threshold yields more detections, although their confidence is lower.
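The effect of the threshold can be illustrated without running the network. The rows below are made-up example values in the same layout as the script's detections (class id at index 1, score at index 2), and `filter_detections` is a hypothetical helper that applies the same `score > thr` test as the loop in the script.

```python
def filter_detections(rows, thr):
    """Keep (class_id, score) pairs whose score is above thr."""
    return [(int(row[1]), row[2]) for row in rows if row[2] > thr]


# Made-up detection rows, for illustration only.
rows = [
    [0, 2, 0.95, 0.1, 0.1, 0.4, 0.4],
    [0, 2, 0.60, 0.5, 0.5, 0.9, 0.9],
    [0, 0, 0.30, 0.2, 0.6, 0.3, 0.8],
]

print(len(filter_detections(rows, 0.5)))  # two rows pass the default 'thr'
print(len(filter_detections(rows, 0.2)))  # a lower threshold keeps all three
```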
The detection results for this image with the YOLOv3 and YOLOv4 models can be found in another article of mine:
https://editor.csdn.net/md/?articleId=108015232
References:
[1] https://blog.csdn.net/stjuliet/article/details/97294104