TensorFlow Object Detection: Building Your Own Object Detection Model with Machine Learning

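Every project looks for different target objects. A large, comprehensive model is useful in general. But if you only need to recognize one or a few kinds of objects in the camera view, such a model is unnecessarily heavy and takes up a lot of CPU and storage.

The solution is to use machine learning to build a model for just the objects you care about. In this project my target object is a teacup. I collected a large number of cup images online, annotated them, trained my own model on them, and used it to detect cups reliably.

The first image below shows some of the teacup pictures I collected online. The second shows cup detection with my own model: every object other than a cup is ignored, so the target is detected faster and more accurately, without interference from other objects.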





The complete original tutorial has about five parts, mostly in English. The videos are hosted on YouTube, which requires a VPN from mainland China:

https://pythonprogramming.net/video-tensorflow-object-detection-api-tutorial/


Part 1:

Reference:

Video and basic introduction:

https://pythonprogramming.net/custom-objects-tracking-tensorflow-object-detection-api-tutorial/


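Installing LabelImg:

Download: https://github.com/tzutalin/labelImg

Installation steps (Windows + Anaconda):

1. Install Qt. Option 1: in the Anaconda Navigator, click to install qtconsole. Option 2: in the Terminal panel at the bottom of PyCharm, with the project's environment selected, run: conda install pyqt=5

2. Build the resources: in the downloaded directory, run: pyrcc5 -o resources.py resources.qrc

3. Start the tool with python labelImg.py and draw a box around every cup in each image. By default LabelImg saves one Pascal VOC XML annotation file per image.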


Part 2:

Reference:

https://pythonprogramming.net/creating-tfrecord-files-tensorflow-object-detection-api-tutorial/


Converting the XML files to a CSV file:

Code:

https://github.com/datitran/raccoon_dataset/blob/master/xml_to_csv.py

Run it as is.
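The linked xml_to_csv.py gathers all LabelImg annotations into one CSV. The stdlib-only sketch below shows roughly what such a conversion does; it is not the script itself. It assumes LabelImg's Pascal VOC XML layout and the column order filename, width, height, class, xmin, ymin, xmax, ymax. The function names are hypothetical.

```python
import csv
import glob
import os
import xml.etree.ElementTree as ET

# Assumed CSV column order (one row per bounding box).
COLUMNS = ["filename", "width", "height", "class", "xmin", "ymin", "xmax", "ymax"]


def xml_to_rows(xml_path):
    """Return one row per <object> box in a Pascal VOC annotation file."""
    root = ET.parse(xml_path).getroot()
    filename = root.find("filename").text
    size = root.find("size")
    width = int(size.find("width").text)
    height = int(size.find("height").text)
    rows = []
    for obj in root.findall("object"):
        box = obj.find("bndbox")
        # float() first, because some tools write coordinates like "300.0".
        rows.append([
            filename, width, height, obj.find("name").text,
            int(float(box.find("xmin").text)), int(float(box.find("ymin").text)),
            int(float(box.find("xmax").text)), int(float(box.find("ymax").text)),
        ])
    return rows


def xml_dir_to_csv(xml_dir, csv_path):
    """Write all boxes from xml_dir/*.xml to csv_path; return the number of boxes."""
    rows = []
    for xml_path in sorted(glob.glob(os.path.join(xml_dir, "*.xml"))):
        rows.extend(xml_to_rows(xml_path))
    with open(csv_path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(COLUMNS)
        writer.writerows(rows)
    return len(rows)
```

For example, xml_dir_to_csv("images/train", "data/train_labels.csv") would produce the training CSV used in the commands below, assuming the training annotations sit in a hypothetical images/train folder.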

Generating the TFRecord files for the object detection model:

Code:

https://github.com/datitran/raccoon_dataset/blob/master/generate_tfrecord.py

Run:

python generate_tfrecord.py --csv_input=data/train_labels.csv  --output_path=data/train.record

python generate_tfrecord.py --csv_input=data/test_labels.csv  --output_path=data/test.record
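Before writing each tf.train.Example, generate_tfrecord.py groups the CSV rows by image and divides every box coordinate by the image width or height, because the Object Detection API expects box coordinates normalized to [0, 1]. The following stdlib-only sketch covers just that preprocessing step. The function name and dictionary layout are hypothetical; the real script then packs these lists into TFRecord features.

```python
import csv
from collections import OrderedDict


def group_normalized_boxes(csv_path):
    """Map filename -> per-image lists of boxes normalized to [0, 1]."""
    groups = OrderedDict()  # keeps images in the order they appear in the CSV
    with open(csv_path, newline="") as f:
        for row in csv.DictReader(f):
            width, height = float(row["width"]), float(row["height"])
            entry = groups.setdefault(row["filename"], {
                "xmins": [], "xmaxs": [], "ymins": [], "ymaxs": [], "classes_text": []})
            # x coordinates are divided by the width, y coordinates by the height.
            entry["xmins"].append(float(row["xmin"]) / width)
            entry["xmaxs"].append(float(row["xmax"]) / width)
            entry["ymins"].append(float(row["ymin"]) / height)
            entry["ymaxs"].append(float(row["ymax"]) / height)
            entry["classes_text"].append(row["class"])
    return groups
```

Grouping by filename matters because one image can contain several cups, and all of its boxes belong in the same record.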

 

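Parts of the script have to be adapted to your own data, for example the mapping from class names to IDs and the folder the images are read from.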
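In the raccoon example, the class_text_to_int function in generate_tfrecord.py only knows the label 'raccoon'. The IDs it returns must match the label map file (data/object-detection.pbtxt in the detection script of Part 5), and IDs start at 1 because 0 is reserved for the background. The sketch below derives both from one list, so they cannot drift apart. It assumes the LabelImg label was typed as 'cup', which is hypothetical; use whatever name you annotated with.

```python
# Hypothetical: the label names exactly as typed in LabelImg.
CLASS_NAMES = ["cup"]

# IDs start at 1; 0 is reserved for the background class.
LABEL_TO_ID = {name: i + 1 for i, name in enumerate(CLASS_NAMES)}


def class_text_to_int(row_label):
    """Drop-in replacement for the hard-coded mapping; raises on unknown labels."""
    if row_label not in LABEL_TO_ID:
        raise ValueError("unknown label: %r" % row_label)
    return LABEL_TO_ID[row_label]


def label_map_text():
    """Render the same mapping in the .pbtxt label-map format."""
    items = []
    for name, class_id in LABEL_TO_ID.items():
        items.append("item {\n  id: %d\n  name: '%s'\n}\n" % (class_id, name))
    return "\n".join(items)
```

Writing label_map_text() to data/object-detection.pbtxt keeps the TFRecords and the detection script consistent. Raising an error on unknown labels, instead of silently returning None, surfaces annotation typos before training starts.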


Part 3:

Reference:

https://pythonprogramming.net/training-custom-objects-tensorflow-object-detection-api-tutorial/

 

Source of the ssd_mobilenet_v1_pets.config file:

https://github.com/tensorflow/models/blob/master/research/object_detection/samples/configs/ssd_mobilenet_v1_pets.config

 

Pre-trained model ssd_mobilenet_v1_coco_11_06_2017.tar.gz (download and extract it):

http://download.tensorflow.org/models/object_detection/ssd_mobilenet_v1_coco_11_06_2017.tar.gz

 

Train the model:

python train.py --logtostderr --train_dir=training/ --pipeline_config_path=training/ssd_mobilenet_v1_pets.config
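Before running train.py, the sample ssd_mobilenet_v1_pets.config has to point at your own data. In the sample, num_classes is set for the pets dataset and the file paths are placeholders. The minimal sketch below patches the fields num_classes, fine_tune_checkpoint, input_path and label_map_path with regular expressions. The function name is hypothetical, and it assumes the first input_path belongs to the training reader and the second to the evaluation reader, as in the sample file. Editing the file by hand works just as well.

```python
import re


def patch_pipeline_config(text, num_classes, checkpoint, train_record,
                          eval_record, label_map):
    """Return the config text with the class count and data paths replaced.

    Assumes the first input_path is in train_input_reader and the second
    in eval_input_reader. Lambdas are used as replacements so that
    backslashes in paths are inserted literally, not as regex escapes.
    """
    text = re.sub(r"num_classes:\s*\d+",
                  lambda m: "num_classes: %d" % num_classes, text)
    text = re.sub(r'fine_tune_checkpoint:\s*"[^"]*"',
                  lambda m: 'fine_tune_checkpoint: "%s"' % checkpoint, text)
    text = re.sub(r'label_map_path:\s*"[^"]*"',
                  lambda m: 'label_map_path: "%s"' % label_map, text)
    records = iter([train_record, eval_record])
    text = re.sub(r'input_path:\s*"[^"]*"',
                  lambda m: 'input_path: "%s"' % next(records), text, count=2)
    return text
```

For this article's layout, the call would be patch_pipeline_config(text, 1, "ssd_mobilenet_v1_coco_11_06_2017/model.ckpt", "data/train.record", "data/test.record", "data/object-detection.pbtxt"). This assumes the tarball extracts to a folder of that name containing model.ckpt files.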


Part 4:

Export the trained checkpoint as a frozen inference graph. General form:

python export_inference_graph.py \
    --input_type image_tensor \
    --pipeline_config_path path/to/ssd_inception_v2.config \
    --trained_checkpoint_prefix path/to/model.ckpt \
    --output_directory path/to/exported_model_directory

 

For this project:

python export_inference_graph.py \
    --input_type image_tensor \
    --pipeline_config_path training/ssd_mobilenet_v1_pets.config \
    --trained_checkpoint_prefix training/model.ckpt-553 \
    --output_directory cup_graph

 

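The same command on a single line, for Windows cmd, which does not support the trailing backslash as a line continuation:

python export_inference_graph.py --input_type image_tensor --pipeline_config_path training/ssd_mobilenet_v1_pets.config --trained_checkpoint_prefix training/model.ckpt-553 --output_directory cup_graph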
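The number in model.ckpt-553 is the training step of the checkpoint being exported. Because train.py keeps saving new checkpoints, this number has to match whichever checkpoint you want to export. Below is a small stdlib helper that picks the highest-numbered checkpoint prefix in a directory. The function name is hypothetical, and it assumes the usual model.ckpt-<step>.index file naming. If TensorFlow is at hand, tf.train.latest_checkpoint("training") returns the newest checkpoint recorded in the directory's checkpoint file.

```python
import glob
import os
import re


def latest_checkpoint_prefix(train_dir, name="model.ckpt"):
    """Return e.g. 'training/model.ckpt-553' for the highest step, or None."""
    pattern = re.compile(re.escape(name) + r"-(\d+)\.index$")
    best_step, best_prefix = -1, None
    for path in glob.glob(os.path.join(train_dir, name + "-*.index")):
        match = pattern.search(os.path.basename(path))
        # Compare steps as integers: as strings, "9" would sort after "553".
        if match and int(match.group(1)) > best_step:
            best_step = int(match.group(1))
            best_prefix = path[: -len(".index")]
    return best_prefix
```

The result of latest_checkpoint_prefix("training") can be passed directly as --trained_checkpoint_prefix.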

 

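This produces the exported model folder cup_graph, which contains frozen_inference_graph.pb. Then adapt the detection source code to load it, either for still images or for video.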


Part 5:

Adapt the application. It can be built by modifying the earlier "TensorFlow real-time webcam object detection" program.

Full code:

import os
import sys

import cv2
import numpy as np
import tensorflow as tf

# Compare (major, minor) numerically; a plain string comparison would treat '1.10' as older than '1.4'.
TF_VERSION = tuple(int(part) for part in tf.__version__.split('.')[:2])
if TF_VERSION < (1, 4):
  raise ImportError('Please upgrade your tensorflow installation to v1.4.* or later!')

# Assumes the script lives in models/research/object_detection;
# ".." makes the parent research/ directory importable as well.
sys.path.append("..")

from utils import label_map_util
from utils import visualization_utils as vis_util

# Directory written by export_inference_graph.py (Part 4).
MODEL_NAME = 'cup_graph'

# Path to the frozen detection graph. This is the actual model used for detection.
PATH_TO_CKPT = os.path.join(MODEL_NAME, 'frozen_inference_graph.pb')

# Label map containing the string used as the label for each box.
PATH_TO_LABELS = os.path.join('data', 'object-detection.pbtxt')

# Only one class (the cup) was trained.
NUM_CLASSES = 1

# Load the frozen graph into memory.
detection_graph = tf.Graph()
with detection_graph.as_default():
  od_graph_def = tf.GraphDef()
  with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
    serialized_graph = fid.read()
    od_graph_def.ParseFromString(serialized_graph)
    tf.import_graph_def(od_graph_def, name='')

# Map class IDs to display names.
label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=NUM_CLASSES, use_display_name=True)
category_index = label_map_util.create_category_index(categories)

# Still-image variant only: convert a PIL image to a uint8 NumPy array.
def load_image_into_numpy_array(image):
  (im_width, im_height) = image.size
  return np.array(image.getdata()).reshape(
      (im_height, im_width, 3)).astype(np.uint8)

# Still-image variant only: test_images/image3.jpg to image5.jpg.
# To test the code with your own images, add their paths to TEST_IMAGE_PATHS.
PATH_TO_TEST_IMAGES_DIR = 'test_images'
TEST_IMAGE_PATHS = [os.path.join(PATH_TO_TEST_IMAGES_DIR, 'image{}.jpg'.format(i)) for i in range(3, 6)]

# Size, in inches, of the output images (still-image variant only).
IMAGE_SIZE = (12, 8)

cap = cv2.VideoCapture(0)  # default webcam

with detection_graph.as_default():
  with tf.Session(graph=detection_graph) as sess:
    # Look up the input and output tensors once, outside the frame loop.
    image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
    # Each box represents a part of the image where a particular object was detected.
    boxes_tensor = detection_graph.get_tensor_by_name('detection_boxes:0')
    # Each score is the confidence for the corresponding object.
    # The score is shown on the result image, together with the class label.
    scores_tensor = detection_graph.get_tensor_by_name('detection_scores:0')
    classes_tensor = detection_graph.get_tensor_by_name('detection_classes:0')
    num_detections_tensor = detection_graph.get_tensor_by_name('num_detections:0')
    while True:
      ret, image_np = cap.read()
      if not ret:  # camera disconnected or no frame available
        break
      # flipCode=0 flips the frame vertically; keep this only if your camera
      # image is upside down (use 1 for a horizontal mirror).
      image_np = cv2.flip(image_np, 0)
      # OpenCV delivers BGR frames, while the model was trained on RGB images.
      image_rgb = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB)
      # Expand dimensions since the model expects images to have shape: [1, None, None, 3]
      image_np_expanded = np.expand_dims(image_rgb, axis=0)
      # Actual detection.
      (boxes, scores, classes, num_detections) = sess.run(
          [boxes_tensor, scores_tensor, classes_tensor, num_detections_tensor],
          feed_dict={image_tensor: image_np_expanded})

      # Draw the boxes and labels on the displayed frame.
      vis_util.visualize_boxes_and_labels_on_image_array(
          image_np,
          np.squeeze(boxes),
          np.squeeze(classes).astype(np.int32),
          np.squeeze(scores),
          category_index,
          use_normalized_coordinates=True,
          line_thickness=8)

      cv2.imshow('object detection', cv2.resize(image_np, (800, 600)))
      if cv2.waitKey(25) & 0xFF == ord('q'):
        break

cap.release()
cv2.destroyAllWindows()
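visualize_boxes_and_labels_on_image_array only draws the results. To use the detections in your own logic, for example to count or crop the cups, you need to filter the boxes by score and scale them to pixels. The API returns each box as normalized [ymin, xmin, ymax, xmax]. The NumPy sketch below does this conversion. The function name is hypothetical; the 0.5 threshold mirrors the drawing function's default min_score_thresh.

```python
import numpy as np


def detections_to_pixels(boxes, scores, classes, width, height, min_score=0.5):
    """Return (class_id, score, (xmin, ymin, xmax, ymax)) tuples in pixels.

    boxes: (N, 4) array of normalized [ymin, xmin, ymax, xmax];
    scores, classes: (N,) arrays, e.g. np.squeeze(...) of the sess.run outputs.
    """
    boxes = np.asarray(boxes, dtype=np.float32)
    scores = np.asarray(scores, dtype=np.float32)
    classes = np.asarray(classes).astype(np.int32)
    keep = scores >= min_score  # drop low-confidence detections
    results = []
    for (ymin, xmin, ymax, xmax), score, cls in zip(boxes[keep], scores[keep], classes[keep]):
        # Note the reordering: the output tuple is (xmin, ymin, xmax, ymax).
        pixel_box = (int(round(xmin * width)), int(round(ymin * height)),
                     int(round(xmax * width)), int(round(ymax * height)))
        results.append((int(cls), float(score), pixel_box))
    return results
```

Inside the loop above it could be called as detections_to_pixels(np.squeeze(boxes), np.squeeze(scores), np.squeeze(classes), image_np.shape[1], image_np.shape[0]); len() of the result is then the number of cups in the frame.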





