Contents
- Calling TensorFlow from OpenCV
  - Overview
  - Using TensorFlow
  - Calling TensorFlow models
- Calling YOLO from OpenCV
Calling TensorFlow from OpenCV
Overview
✔️ OpenCV's DNN module can directly load models trained and exported with the TensorFlow Object Detection API. Supported architectures include:
- SSD
- Faster-RCNN
- Mask-RCNN
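✔️ With these three classic detection networks, you can cover the whole pipeline: train a model in TensorFlow, export it, and then load the network in OpenCV DNN for custom object detection.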
✔️ OpenCV 3.4.1 and later can use models exported from the TensorFlow (1.11 and later) Object Detection API. The currently supported models are:
Model | Version | Weights | Config
--- | --- | --- | ---
MobileNet-SSD v1 | 2017_11_17 | weights | config
MobileNet-SSD v1 PPN | 2018_07_03 | weights | config
MobileNet-SSD v2 | 2018_03_29 | weights | config
Inception-SSD v2 | 2017_11_17 | weights | config
Faster-RCNN Inception v2 | 2018_01_28 | weights | config
Faster-RCNN ResNet-50 | 2018_01_28 | weights | config
Mask-RCNN Inception v2 | 2018_01_28 | weights | config
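✏️ Workflow: train a model with transfer learning using the TensorFlow Object Detection API and export the inference graph, then run one of the Python scripts shipped with OpenCV 3.4.1+ to generate a graph configuration file. With that file, the TensorFlow model can be used in the OpenCV DNN module.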
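The conversion step can be scripted. Below is a minimal sketch, assuming OpenCV's `samples/dnn` directory is available at a local path (the default path and the helper name `text_graph_command` are hypothetical) and that the exported model directory keeps the standard `frozen_inference_graph.pb` and `pipeline.config` files. It picks `tf_text_graph_ssd.py`, `tf_text_graph_faster_rcnn.py` or `tf_text_graph_mask_rcnn.py` by detector family and returns the argument list rather than running it.

```python
import sys
from pathlib import Path
from typing import List

# Converter scripts from OpenCV's samples/dnn directory, keyed by detector family.
TEXT_GRAPH_SCRIPTS = {
    "ssd": "tf_text_graph_ssd.py",
    "faster_rcnn": "tf_text_graph_faster_rcnn.py",
    "mask_rcnn": "tf_text_graph_mask_rcnn.py",
}


def text_graph_command(family: str, model_dir: str,
                       samples_dir: str = "opencv/samples/dnn") -> List[str]:
    """Build the argv that turns an exported model directory into graph.pbtxt.

    Hypothetical helper (not an OpenCV API). Assumes model_dir contains the
    standard export files frozen_inference_graph.pb and pipeline.config.
    """
    if family not in TEXT_GRAPH_SCRIPTS:
        raise ValueError(f"unsupported detector family: {family!r}")
    model = Path(model_dir)
    script = Path(samples_dir) / TEXT_GRAPH_SCRIPTS[family]
    return [
        sys.executable, str(script),
        "--input", str(model / "frozen_inference_graph.pb"),
        "--config", str(model / "pipeline.config"),
        "--output", str(model / "graph.pbtxt"),
    ]
```

Passing the result to `subprocess.run(cmd, check=True)` writes the text graph to the `--output` path; the `.pb`/`.pbtxt` pair is then what `cv2.dnn.readNetFromTensorflow(model, config)` loads. Returning the argv keeps the helper free of side effects and easy to inspect before running.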
Using TensorFlow
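A frozen graph exported by the Object Detection API returns four batched tensors: `num_detections`, `detection_scores`, `detection_boxes` (normalized `[ymin, xmin, ymax, xmax]`) and `detection_classes`. The pure-Python sketch below (the helper name `decode_detections` is hypothetical) shows how the first batch entry maps to pixel-space boxes on the original image, mirroring the post-processing in the script that follows; it uses the same 0.8 score threshold.

```python
from typing import List, Sequence, Tuple

# One detection in pixel coordinates: (class_id, score, left, top, right, bottom).
Detection = Tuple[int, float, float, float, float, float]


def decode_detections(num_detections: Sequence[float],
                      scores: Sequence[Sequence[float]],
                      boxes: Sequence[Sequence[Sequence[float]]],
                      classes: Sequence[Sequence[float]],
                      rows: int, cols: int,
                      threshold: float = 0.8) -> List[Detection]:
    """Convert batch entry 0 of the detector outputs into pixel boxes.

    Boxes are normalized to [0, 1], so scaling by the original image's
    height (rows) and width (cols) works even though the network saw a
    resized copy of the whole image.
    """
    results: List[Detection] = []
    for i in range(int(num_detections[0])):
        score = float(scores[0][i])
        if score <= threshold:  # keep only detections with score > threshold
            continue
        ymin, xmin, ymax, xmax = (float(v) for v in boxes[0][i])
        results.append((int(classes[0][i]), score,
                        xmin * cols, ymin * rows, xmax * cols, ymax * rows))
    return results
```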
✔️ Running inference with TensorFlow directly:
# TensorFlow 1.x API (tf.gfile, tf.GraphDef, tf.Session)
import tensorflow as tf
import cv2

# Read the frozen inference graph.
model_dir = '../faster_rcnn_resnet50_coco_2018_01_28/frozen_inference_graph.pb'
with tf.gfile.FastGFile(model_dir, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

with tf.Session() as sess:
    # Import the frozen graph into the session's graph
    sess.graph.as_default()
    tf.import_graph_def(graph_def, name='')

    # Read and preprocess an image.
    img = cv2.imread('cat.jpg')
    rows = img.shape[0]
    cols = img.shape[1]
    inp = cv2.resize(img, (300, 300))
    inp = inp[:, :, [2, 1, 0]]  # BGR2RGB

    # Run the model on a batch of one image
    out = sess.run([sess.graph.get_tensor_by_name('num_detections:0'),
                    sess.graph.get_tensor_by_name('detection_scores:0'),
                    sess.graph.get_tensor_by_name('detection_boxes:0'),
                    sess.graph.get_tensor_by_name('detection_classes:0')],
                   feed_dict={'image_tensor:0': inp.reshape(1, inp.shape[0], inp.shape[1], 3)})

    # Visualize detected bounding boxes.
    num_detections = int(out[0][0])
    for i in range(num_detections):
        classId = int(out[3][0][i])
        score = float(out[1][0][i])
        # Normalized box: [ymin, xmin, ymax, xmax]
        bbox = [float(v) for v in out[2][0][i]]
        if score > 0.8:
            x = bbox[1] * cols
            y = bbox[0] * rows
            right = bbox[3] * cols
botto