Tensorflow中的物体识别API的demo实现

算法的python实现 专栏收录该内容
1 篇文章 0 订阅

环境:python3.5 需要安装代码中对应的python库和tensorflow库
一、 简述
TensorFlow提供了一个物体识别的API开发包,可以较为准确的识别出图片或者视频中的不同种类的物体并完成跟踪。具体的算法原理没进行深究,以下通过两个demo程序简单描述下API的应用及实现,其中第一个为实体图片的识别demo;另一个为视频中的实体识别demo。具体的tensorflow软件开发包连接地址为:https://github.com/tensorflow/models/tree/477ed41e7e4e8a8443bc633846eb01e2182dc68a/object_detection
二、 实体图片的识别Demo
这个Demo程序基本照抄了TensorFlow提供的示例程序,有部分只是做了简单粗暴的修改,目的是实际验证下其API的实体识别和匹配能力。
首先需要为API寻找图片训练样本,tensorflow提供的图片训练样本为coco图片库,其中frozen_inference_graph.pb文件为训练模型最终生成的文件,这个文件是本身提供好的,如果需要训练自己的样本,还应提前完成pb训练文件的生成工作。Demo程序中采用了在线下载压缩包和解压的方法,第一次下载好后,下载压缩包的程序便可以被引去。
然后就是获取测试样本,需要将测试样本通过python的PIL处理最终由numpy输出为元组的形式。其实图片的数据化处理还可以交由opencv库完成,比如cv2.imread(image_path)可以直接获得图片的元组数据。
接下来就是tensorflow物体识别中最为核心的部分,由tf.Graph()完成图片的物体识别。获取image的TensorFlow存在形式、图片的匹配程度,图片所属的物体类别、图片的匹配框以及图片需要检测的各个部分。这些参数会存放到一个专门的数组中,便于后期显示。
最后就是采用matplotlib库实现显示功能,除了plt.imshow(image_np)语句外,最后又添加了pylab.show()才可以真正将图片显示出来。
源代码如下:

import numpy as np
import os
import six.moves.urllib as urllib
import sys
import tarfile
import tensorflow as tf
import zipfile
import pylab
from collections import defaultdict
from io import StringIO
from matplotlib import pyplot as plt
from PIL import Image
import cv2
# This is needed to display the images.
#%matplotlib inline
# This is needed since the notebook is stored in the object_detection folder.
sys.path.append("..")
from utils import label_map_util
from utils import visualization_utils as vis_util
# What model to download.
MODEL_NAME = 'ssd_mobilenet_v1_coco_11_06_2017'
MODEL_FILE = MODEL_NAME + '.tar.gz'
# Path to frozen detection graph. This is the actual model that is used for the object detection.
PATH_TO_CKPT = MODEL_NAME + '/frozen_inference_graph.pb'
# List of the strings that is used to add correct label for each box.
PATH_TO_LABELS = os.path.join('data', 'mscoco_label_map.pbtxt')
NUM_CLASSES = 90
tar_file = tarfile.open(MODEL_FILE)
for file in tar_file.getmembers():
  file_name = os.path.basename(file.name)
  if 'frozen_inference_graph.pb' in file_name:
    tar_file.extract(file, os.getcwd())

detection_graph = tf.Graph()
with detection_graph.as_default():
  od_graph_def = tf.GraphDef()
  with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
    serialized_graph = fid.read()
    od_graph_def.ParseFromString(serialized_graph)
    tf.import_graph_def(od_graph_def, name='')

label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=NUM_CLASSES, use_display_name=True)
category_index = label_map_util.create_category_index(categories)

def load_image_into_numpy_array(image):
  (im_width, im_height) = image.size
  return np.array(image.getdata()).reshape(
      (im_height, im_width, 3)).astype(np.uint8)

# For the sake of simplicity we will use only 2 images:
# image1.jpg
# image2.jpg
# If you want to test the code with your images, just add path to the images to the TEST_IMAGE_PATHS.
PATH_TO_TEST_IMAGES_DIR = 'test_images'
TEST_IMAGE_PATHS = [ os.path.join(PATH_TO_TEST_IMAGES_DIR, 'image{}.jpg'.format(i)) for i in range(1, 3) ]

# Size, in inches, of the output images.
IMAGE_SIZE = (12, 8)

with detection_graph.as_default():
  with tf.Session(graph=detection_graph) as sess:
    for image_path in TEST_IMAGE_PATHS:
      image = Image.open(image_path)

      # the array based representation of the image will be used later in order to prepare the
      # result image with boxes and labels on it.
      image_np = load_image_into_numpy_array(image)
      #image_np = cv2.imread(image_path)

      # Expand dimensions since the model expects images to have shape: [1, None, None, 3]
      image_np_expanded = np.expand_dims(image_np, axis=0)
      image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
      # Each box represents a part of the image where a particular object was detected.
      boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
      # Each score represent how level of confidence for each of the objects.
      # Score is shown on the result image, together with the class label.
      scores = detection_graph.get_tensor_by_name('detection_scores:0')
      classes = detection_graph.get_tensor_by_name('detection_classes:0')
      num_detections = detection_graph.get_tensor_by_name('num_detections:0')
      # Actual detection.
      (boxes, scores, classes, num_detections) = sess.run(
          [boxes, scores, classes, num_detections],
          feed_dict={image_tensor: image_np_expanded})
      # Visualization of the results of a detection.
      vis_util.visualize_boxes_and_labels_on_image_array(
          image_np,
          np.squeeze(boxes),
          np.squeeze(classes).astype(np.int32),
          np.squeeze(scores),
          category_index,
          use_normalized_coordinates=True,
          line_thickness=8)
      plt.figure(figsize=IMAGE_SIZE)

      plt.imshow(image_np)
      pylab.show()  

显示的效果如下:
这里写图片描述
这里写图片描述
三、 视频中的实体识别demo
视频识别是图片识别的一个变种,视频就是截获每一个帧,并将其转换为图片,然后利用上述提到的tensorflow实体识别核心算法完成对每个图片的识别与跟踪。
这里用到了moviepy库完成对视频的编辑功能。主要代码如下:

clip = VideoFileClip("video1.mp4").subclip(0,2)       #moviepy acquire the information of video
white_clip = clip.fl_image(process_image)             #NOTE: this function expects color images!!s
white_clip.write_videofile(white_output, audio=False) #the movie though tensorflow object_detection

HTML("""
<video width="960" height="540" controls>
  <source src="{0}">
</video>
""".format(white_output))
finalclip = VideoFileClip("video_out.mp4")
finalclip.write_gif("final.gif")

后续为了便于显示,又将mp4格式的文件转化成gif格式文件显示。Video1.mp4文件可以自行选择。
源代码如下:

# Import everything needed to edit/save/watch video clips
from moviepy.editor import VideoFileClip
import tensorflow as tf
from IPython.display import HTML
from PIL import Image

import numpy as np
import os
import six.moves.urllib as urllib
import sys
import tarfile
import zipfile
import pylab
from collections import defaultdict
from io import StringIO

sys.path.append("..")

from utils import label_map_util

from utils import visualization_utils as vis_util

# What model to download.
MODEL_NAME = 'ssd_mobilenet_v1_coco_11_06_2017'
MODEL_FILE = MODEL_NAME + '.tar.gz'


# Path to frozen detection graph. This is the actual model that is used for the object detection.
PATH_TO_CKPT = MODEL_NAME + '/frozen_inference_graph.pb'

# List of the strings that is used to add correct label for each box.
PATH_TO_LABELS = os.path.join('data', 'mscoco_label_map.pbtxt')

NUM_CLASSES = 90

tar_file = tarfile.open(MODEL_FILE)
for file in tar_file.getmembers():
  file_name = os.path.basename(file.name)
  if 'frozen_inference_graph.pb' in file_name:
    tar_file.extract(file, os.getcwd())

detection_graph = tf.Graph()
with detection_graph.as_default():
  od_graph_def = tf.GraphDef()
  with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
    serialized_graph = fid.read()
    od_graph_def.ParseFromString(serialized_graph)
    tf.import_graph_def(od_graph_def, name='')

label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=NUM_CLASSES, use_display_name=True)
category_index = label_map_util.create_category_index(categories)


def detect_objects(image_np, sess, detection_graph):
    # Expand dimensions since the model expects images to have shape: [1, None, None, 3]
    image_np_expanded = np.expand_dims(image_np, axis=0)
    image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')

    # Each box represents a part of the image where a particular object was detected.
    boxes = detection_graph.get_tensor_by_name('detection_boxes:0')

    # Each score represent how level of confidence for each of the objects.
    # Score is shown on the result image, together with the class label.
    scores = detection_graph.get_tensor_by_name('detection_scores:0')
    classes = detection_graph.get_tensor_by_name('detection_classes:0')
    num_detections = detection_graph.get_tensor_by_name('num_detections:0')

    # Actual detection.
    (boxes, scores, classes, num_detections) = sess.run(
        [boxes, scores, classes, num_detections],
        feed_dict={image_tensor: image_np_expanded})

    # Visualization of the results of a detection.
    vis_util.visualize_boxes_and_labels_on_image_array(
        image_np,
        np.squeeze(boxes),
        np.squeeze(classes).astype(np.int32),
        np.squeeze(scores),
        category_index,
        use_normalized_coordinates=True,
        line_thickness=8)
    return image_np

def process_image(image):
    # NOTE: The output you return should be a color image (3 channel) for processing video below
    # you should return the final output (image with lines are drawn on lanes)
    with detection_graph.as_default():
        with tf.Session(graph=detection_graph) as sess:
            image_process = detect_objects(image, sess, detection_graph)
            return image_process

white_output = 'video_out.mp4'

clip = VideoFileClip("video1.mp4").subclip(0,2)       #moviepy acquire the information of video
white_clip = clip.fl_image(process_image)             #NOTE: this function expects color images!!s
white_clip.write_videofile(white_output, audio=False) #the movie though tensorflow object_detection

HTML("""
<video width="960" height="540" controls>
  <source src="{0}">
</video>
""".format(white_output))
finalclip = VideoFileClip("video_out.mp4")
finalclip.write_gif("final.gif")

显示的效果如下:
这里写图片描述

  • 2
    点赞
  • 22
    评论
  • 15
    收藏
  • 一键三连
    一键三连
  • 扫一扫,分享海报

©️2021 CSDN 皮肤主题: 大白 设计师:CSDN官方博客 返回首页
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、C币套餐、付费专栏及课程。

余额充值