TensorFlow —— 30秒搞定物体检测

Google发布了新的TensorFlow物体检测API,包含了预训练模型,一个发布模型的jupyter notebook,一些可用于使用自己数据集对模型进行重新训练的有用脚本。
使用该API可以快速的构建一些图片中物体检测的应用。这里我们一步一步来看如何使用预训练模型来检测图像中的物体。

1、首先我们载入一些会使用的库

import numpy as np  
import os  
import six.moves.urllib as urllib  
import sys  
import tarfile  
import tensorflow as tf  
import zipfile  
  
from collections import defaultdict  
from io import StringIO  
from matplotlib import pyplot as plt  
from PIL import Image  

2、接下来进行环境设置

%matplotlib inline  
sys.path.append("..")  

3、物体检测载入

from utils import label_map_util  
  
from utils import visualization_utils as vis_util  

4、准备模型

变量 任何使用export_inference_graph.py工具输出的模型可以在这里载入,只需简单改变PATH_TO_CKPT指向一个新的.pb文件。这里我们使用“移动网SSD”模型。

MODEL_NAME = 'ssd_mobilenet_v1_coco_11_06_2017'  
MODEL_FILE = MODEL_NAME + '.tar.gz'  
DOWNLOAD_BASE = 'http://download.tensorflow.org/models/object_detection/'  
  
PATH_TO_CKPT = MODEL_NAME + '/frozen_inference_graph.pb'  
  
PATH_TO_LABELS = os.path.join('data', 'mscoco_label_map.pbtxt')  
  
NUM_CLASSES = 90  

5、下载模型

opener = urllib.request.URLopener()  
opener.retrieve(DOWNLOAD_BASE + MODEL_FILE, MODEL_FILE)  
tar_file = tarfile.open(MODEL_FILE)  
for file in tar_file.getmembers():  
    file_name = os.path.basename(file.name)  
    if 'frozen_inference_graph.pb' in file_name:  
        tar_file.extract(file, os.getcwd())  

6、将(frozen)TensorFlow模型载入内存

detection_graph = tf.Graph()  
with detection_graph.as_default():  
    od_graph_def = tf.GraphDef()  
    with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:  
        serialized_graph = fid.read()  
        od_graph_def.ParseFromString(serialized_graph)  
        tf.import_graph_def(od_graph_def, name='')  

7、载入标签图

标签图将索引映射到类名称,当我们的卷积预测5时,我们知道它对应飞机。这里我们使用内置函数,但是任何返回将整数映射到恰当字符标签的字典都适用。

label_map = label_map_util.load_labelmap(PATH_TO_LABELS)  
categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=NUM_CLASSES, use_display_name=True)  
category_index = label_map_util.create_category_index(categories)  

8、辅助代码

def load_image_into_numpy_array(image):  
  (im_width, im_height) = image.size  
  return np.array(image.getdata()).reshape(  
      (im_height, im_width, 3)).astype(np.uint8)  

9、检测

PATH_TO_TEST_IMAGES_DIR = 'test_images'  
TEST_IMAGE_PATHS = [ os.path.join(PATH_TO_TEST_IMAGES_DIR, 'image{}.jpg'.format(i)) for i in range(1, 3) ]  
IMAGE_SIZE = (12, 8)  
with detection_graph.as_default():  
  
  with tf.Session(graph=detection_graph) as sess:  
    for image_path in TEST_IMAGE_PATHS:  
      image = Image.open(image_path)  
      # 这个array在之后会被用来准备为图片加上框和标签  
      image_np = load_image_into_numpy_array(image)  
      # 扩展维度,应为模型期待: [1, None, None, 3]  
      image_np_expanded = np.expand_dims(image_np, axis=0)  
      image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')  
      # 每个框代表一个物体被侦测到.  
      boxes = detection_graph.get_tensor_by_name('detection_boxes:0')  
      # 每个分值代表侦测到物体的可信度.  
      scores = detection_graph.get_tensor_by_name('detection_scores:0')  
      classes = detection_graph.get_tensor_by_name('detection_classes:0')  
      num_detections = detection_graph.get_tensor_by_name('num_detections:0')  
      # 执行侦测任务.  
      (boxes, scores, classes, num_detections) = sess.run(  
          [boxes, scores, classes, num_detections],  
          feed_dict={image_tensor: image_np_expanded})  
      # 图形化.  
      vis_util.visualize_boxes_and_labels_on_image_array(  
          image_np,  
          np.squeeze(boxes),  
          np.squeeze(classes).astype(np.int32),  
          np.squeeze(scores),  
          category_index,  
          use_normalized_coordinates=True,  
          line_thickness=8)  
      plt.figure(figsize=IMAGE_SIZE)  
      plt.imshow(image_np)  

在载入模型部分可以尝试不同的侦测模型以比较速度和准确度,将你想侦测的图片放入TEST_IMAGE_PATHS中运行即可。

关注我的技术公众号《漫谈人工智能》,每天推送优质文章

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

两只橙

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值