1、概述
上一讲简单的讲了目标检测的原理以及Tensorflow Object Detection API的安装,这一节继续讲Tensorflow Object Detection API怎么用。
2、COCO数据集介绍
COCO数据集是微软发布的一个可以用来进行图像识别训练的数据集,图像中的目标都经过精确的segmentation进行位置定位,COCO数据集包括90类目标。Object Detection API默认提供了5个预训练模型,都是使用COCO数据集训练的,分别为
SSD + MobileNet
Inception V2 + SSD
ResNet101 + R-CNN
ResNet101 + Faster R-CNN
Inception-ResNet V2 + Faster R-CNN
3、下载模型
这个例子中,我们使用基于COCO上训练的ssd_mobilenet_v1_coco模型对任意图片进行识别。打开以下链接,
下载第一个模型。然后,将其解压在object_detection目录下。接下来,写代码。
4、导入模块
首先在my_object_detection目录下新建文件demo1.py。
#encoding:utf-8
import tensorflow as tf
import numpy as np
import os
from matplotlib import pyplot as plt
from PIL import Image
from object_detection.utils import label_map_util
from object_detection.utils import visualization_utils as vis_utils
5、指定文件路径等
#下载下来的模型的目录
MODEL_DIR = 'object_detection/ssd_mobilenet_v1_coco_2018_01_28'
#下载下来的模型的文件
MODEL_CHECK_FILE = os.path.join(MODEL_DIR, 'frozen_inference_graph.pb')
#数据集对于的label
MODEL_LABEL_MAP = os.path.join('object_detection/data', 'mscoco_label_map.pbtxt')
#数据集分类数量,可以打开mscoco_label_map.pbtxt文件看看
MODEL_NUM_CLASSES = 90
#这里是获取实例图片文件名,将其放到数组中
PATH_TO_TEST_IMAGES_DIR = 'object_detection/test_images'
TEST_IMAGES_PATHS = [os.path.join(PATH_TO_TEST_IMAGES_DIR, 'image{}.jpg'.format(i)) for i in range(1, 6)]
#输出图像大小,单位是in
IMAGE_SIZE= (12, 8)
6、导入模型
tf.reset_default_graph()
#将模型读取到默认的图中
with tf.gfile.GFile(MODEL_CHECK_FI