Object recognition with Google's object_detection API (summary notes)

Copyright notice: this is an original article by the blogger; please credit the author and source when reposting. https://blog.csdn.net/xq920831/article/details/83502245

 

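Over the past couple of days I wanted to build a real-time object recognition program. I learned that Google's object_detection API can be used with a live camera feed to recognize the objects in each frame, so I gathered some material and worked through it.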

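What you need:

  • Install Google's object_detection API
  • Install Python 3.5 (I have 3.6 on my MacBook)
  • Install TensorFlow (the code below uses the TensorFlow 1.x API, such as tf.Session)
  • Install the OpenCV package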

 

First, install Google's object_detection API (reference: https://www.jianshu.com/p/8841a47792df).

1. Install Python, TensorFlow, and the other dependencies

pip install tensorflow
pip install pillow
pip install lxml
pip install jupyter
pip install matplotlib
pip install opencv-python
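
A quick way to confirm that the packages above are visible to the interpreter you will actually run is sketched below. The helper name missing_modules is made up for this illustration, and the import names (PIL for pillow, cv2 for OpenCV) are the ones the main code later in this post relies on.

```python
import importlib.util


def missing_modules(names):
    """Return the top-level module names from `names` that cannot be found."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    # These are import names, not pip package names:
    # pillow is imported as PIL, opencv-python as cv2.
    required = ["tensorflow", "PIL", "lxml", "matplotlib", "cv2"]
    print("missing:", missing_modules(required))
```

An empty list means every module was found; anything listed still needs to be installed for that interpreter.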

2. Install Protoc: go to the Protoc download page and download the precompiled zip package for your platform.

After unpacking, the bin directory contains a protoc binary; copy it into place:

cp bin/protoc /usr/local/bin/protoc    

Note: copy it to /usr/local/bin (writable), not /usr/bin (read-only); otherwise you get "Operation not permitted". This step cost me a lot of time.

3. Download the object detection API source code from GitHub

git clone https://github.com/tensorflow/models.git

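4. Compile Protobuf: enter the models directory (models/research if your checkout keeps object_detection under research, as the paths later in this post do) and run the following command: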

protoc object_detection/protos/*.proto --python_out=.

Note: the command must be run from that directory, because the object_detection/protos path is relative to it.

5. From the same directory, add it and its slim subdirectory to the PYTHONPATH environment variable:

export PYTHONPATH=$PYTHONPATH:`pwd`:`pwd`/slim
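
The export only lasts for the current shell session. If the script is launched from somewhere else (an IDE, for example), the same two directories can be added from Python before the object_detection imports. This is only a sketch: add_research_paths and the example path are hypothetical and not part of the API.

```python
import os
import sys


def add_research_paths(research_dir):
    """Put research_dir and research_dir/slim at the front of sys.path.

    research_dir is assumed to be the directory that contains the
    object_detection package (the directory used for the export above).
    Returns the list of paths that were newly added.
    """
    research_dir = os.path.abspath(research_dir)
    added = []
    for path in (os.path.join(research_dir, "slim"), research_dir):
        if path not in sys.path:
            sys.path.insert(0, path)
            added.append(path)
    return added


if __name__ == "__main__":
    # Hypothetical location; replace it with your own checkout.
    print(add_research_paths("models/research"))
```

Calling it a second time with the same directory adds nothing, so it is safe to leave at the top of a script.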

6. Test whether the object detection API was installed successfully; the test script below should finish without errors:

python object_detection/builders/model_builder_test.py

 

Next, the main code:

# coding: utf-8   
  
import numpy as np  
import os  
import six.moves.urllib as urllib  
import sys  
import tarfile  
import tensorflow as tf  
import zipfile  
  
from collections import defaultdict  
from io import StringIO  
from matplotlib import pyplot as plt  
from PIL import Image  
  
import cv2                  # OpenCV, used to read webcam frames and show results
cap = cv2.VideoCapture(0)   # open the default camera (device index 0)
  
# This is needed since the notebook is stored in the object_detection folder.    
sys.path.append("..")  

# ## Object detection imports  
# Here are the imports from the object detection module.   
  
from object_detection.utils import label_map_util
  
from object_detection.utils import visualization_utils as vis_util
  

# # Model preparation   
  
# What model to download.  
MODEL_NAME = 'ssd_mobilenet_v1_coco_11_06_2017'  
#MODEL_NAME = 'faster_rcnn_resnet101_coco_11_06_2017'
#MODEL_NAME = 'ssd_inception_v2_coco_11_06_2017'
MODEL_FILE = MODEL_NAME + '.tar.gz'  
DOWNLOAD_BASE = 'http://download.tensorflow.org/models/object_detection/'  
  
# Path to frozen detection graph. This is the actual model that is used for the object detection.  
PATH_TO_CKPT = MODEL_NAME + '/frozen_inference_graph.pb'  
  
# List of the strings that is used to add correct label for each box.  
PATH_TO_LABELS = os.path.join('models-master/research/object_detection/data', 'mscoco_label_map.pbtxt')
  
NUM_CLASSES = 90  
  
  
# ## Download Model  
  
# In[5]:  
  
# Download and unpack the model only if the frozen graph is not already on disk.
if not os.path.exists(PATH_TO_CKPT):
  urllib.request.urlretrieve(DOWNLOAD_BASE + MODEL_FILE, MODEL_FILE)
  with tarfile.open(MODEL_FILE) as tar_file:
    for member in tar_file.getmembers():
      file_name = os.path.basename(member.name)
      if 'frozen_inference_graph.pb' in file_name:
        tar_file.extract(member, os.getcwd())
  
  
# ## Load a (frozen) Tensorflow model into memory.  
  
# In[6]:  
  
detection_graph = tf.Graph()  
with detection_graph.as_default():  
  od_graph_def = tf.GraphDef()  
  with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:  
    serialized_graph = fid.read()  
    od_graph_def.ParseFromString(serialized_graph)  
    tf.import_graph_def(od_graph_def, name='')  
  
  
# ## Loading label map  
  
# In[7]:  
  
label_map = label_map_util.load_labelmap(PATH_TO_LABELS)  
categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=NUM_CLASSES, use_display_name=True)  
category_index = label_map_util.create_category_index(categories)  
  
  
# ## Helper code  
  
# In[8]:  
  
def load_image_into_numpy_array(image):  
  (im_width, im_height) = image.size  
  return np.array(image.getdata()).reshape(  
      (im_height, im_width, 3)).astype(np.uint8)  
  
  
# # Detection  
  
# In[9]:  
  
# For the sake of simplicity we will use only 2 images:  
# image1.jpg  
# image2.jpg  
# If you want to test the code with your images, just add path to the images to the TEST_IMAGE_PATHS.  
PATH_TO_TEST_IMAGES_DIR = 'test_images'  
TEST_IMAGE_PATHS = [ os.path.join(PATH_TO_TEST_IMAGES_DIR, 'image{}.jpg'.format(i)) for i in range(1, 3) ]  
  
# Size, in inches, of the output images.  
IMAGE_SIZE = (12, 8)  
  
  
# In[10]:  
  
with detection_graph.as_default():  
  with tf.Session(graph=detection_graph) as sess:  
    while True:  # loop over webcam frames instead of TEST_IMAGE_PATHS
      ret, image_np = cap.read()
      if not ret:  # stop if the camera did not return a frame
        break
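      # OpenCV returns frames in BGR order, while the model expects RGB, so convert first.
      image_rgb = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB)
      # Expand dimensions since the model expects images to have shape: [1, None, None, 3]
      image_np_expanded = np.expand_dims(image_rgb, axis=0)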
      image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')  
      # Each box represents a part of the image where a particular object was detected.  
      boxes = detection_graph.get_tensor_by_name('detection_boxes:0')  
      # Each score represents the level of confidence for each of the objects.
      # Score is shown on the result image, together with the class label.  
      scores = detection_graph.get_tensor_by_name('detection_scores:0')  
      classes = detection_graph.get_tensor_by_name('detection_classes:0')  
      num_detections = detection_graph.get_tensor_by_name('num_detections:0')  
      # Actual detection.  
      (boxes, scores, classes, num_detections) = sess.run(  
          [boxes, scores, classes, num_detections],  
          feed_dict={image_tensor: image_np_expanded})  
      # Visualization of the results of a detection.  
      vis_util.visualize_boxes_and_labels_on_image_array(  
          image_np,  
          np.squeeze(boxes),  
          np.squeeze(classes).astype(np.int32),  
          np.squeeze(scores),  
          category_index,  
          use_normalized_coordinates=True,  
          line_thickness=8)  
      cv2.imshow('object detection', cv2.resize(image_np,(800,600)))  
      if cv2.waitKey(25) & 0xFF == ord('q'):  # press 'q' to quit
        cap.release()  # free the camera before closing the window
        cv2.destroyAllWindows()
        break
    
# In[ ]:  
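
The arrays returned by sess.run can also be turned into a plain list of results, for logging or further processing, instead of only being drawn on the frame. The sketch below assumes the usual output layout of this API: each box is [ymin, xmin, ymax, xmax] normalized to the range 0 to 1, and category_index maps a class id to a dict with a 'name' key. The function name summarize_detections and the 0.5 threshold are choices made for this example.

```python
def summarize_detections(boxes, scores, classes, category_index,
                         image_height, image_width, min_score=0.5):
    """Convert one image's detections into a list of readable dicts.

    boxes, scores and classes are the squeezed outputs for a single image
    (plain lists and NumPy arrays both work).
    """
    results = []
    for box, score, class_id in zip(boxes, scores, classes):
        if score < min_score:
            continue  # skip low-confidence detections
        ymin, xmin, ymax, xmax = box
        name = category_index.get(int(class_id), {}).get("name", "unknown")
        results.append({
            "name": name,
            "score": float(score),
            # Scale normalized coordinates to pixels: (left, top, right, bottom).
            "box": (int(xmin * image_width), int(ymin * image_height),
                    int(xmax * image_width), int(ymax * image_height)),
        })
    return results
```

Inside the loop it could be called as summarize_detections(np.squeeze(boxes), np.squeeze(scores), np.squeeze(classes), category_index, image_np.shape[0], image_np.shape[1]).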

Along the way I ran into three notable problems:

1. PATH_TO_LABELS = os.path.join('models-master/research/object_detection/data', 'mscoco_label_map.pbtxt')

This path points to the label data shipped with the downloaded models package (the 90 class labels).
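
Because the path is relative, it only resolves when the script is started from the directory that contains models-master. A small check before loading the label map turns a confusing failure into a clear message. This is a sketch: find_label_map is a hypothetical helper, and the directory layout is the one from the path above.

```python
import os


def find_label_map(models_root, filename="mscoco_label_map.pbtxt"):
    """Return the absolute path to the label map under models_root, or raise.

    models_root is assumed to be the unpacked repository directory,
    for example 'models-master'.
    """
    path = os.path.join(models_root, "research", "object_detection", "data", filename)
    if not os.path.isfile(path):
        raise FileNotFoundError(
            "label map not found at {}; check the working directory".format(
                os.path.abspath(path)))
    return os.path.abspath(path)
```

The returned value can then be assigned to PATH_TO_LABELS in place of the hard-coded string.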

2. TypeError: __new__() got an unexpected keyword argument 'serialized_options'

The cause is that the protobuf version used in the terminal differs from the one used in PyCharm. Make the two versions match.

To check the version, run: protoc --version

Reference: https://blog.csdn.net/Strive_For_Future/article/details/81809578
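
Note that protoc --version reports the compiler, while the Python code imports the protobuf runtime package, which is installed per interpreter (so the terminal and PyCharm can each see a different one). The sketch below prints both versions from whichever interpreter runs it; the function names are made up for this example.

```python
import shutil
import subprocess


def parse_protoc_version(output):
    """Extract the version from `protoc --version` output.

    For example 'libprotoc 3.6.1' gives '3.6.1'; empty output gives None.
    """
    parts = output.strip().split()
    return parts[-1] if parts else None


def protoc_version():
    """Return the version of the protoc compiler on PATH, or None if absent."""
    path = shutil.which("protoc")
    if path is None:
        return None
    result = subprocess.run([path, "--version"], stdout=subprocess.PIPE,
                            universal_newlines=True)
    return parse_protoc_version(result.stdout)


def python_protobuf_version():
    """Return the version of the protobuf Python package, or None if absent."""
    try:
        import google.protobuf
    except ImportError:
        return None
    return google.protobuf.__version__


if __name__ == "__main__":
    print("protoc compiler:", protoc_version())
    print("python protobuf:", python_protobuf_version())
```

Running it once from the terminal and once from the IDE shows whether the two environments really differ.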

3. "from utils import label_map_util" fails with ImportError: cannot import name 'label_map_util'

This happens because object_detection has not been added to the Python search path, so it cannot be imported.

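Solution: before running the program, enter the directory that contains object_detection (the same one as in step 5) and add it to the environment variable: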

export PYTHONPATH=$PYTHONPATH:`pwd`:`pwd`/slim

Then, in the source program, change

from utils import label_map_util

from utils import visualization_utils as vis_util

to:

from object_detection.utils import label_map_util

from object_detection.utils import visualization_utils as vis_util

That is all that is needed.

 

 
