使用TensorFlow Object Detection API进行图像物体检测

参考 https://github.com/tensorflow/models/tree/master/object_detection

使用TensorFlow Object Detection API进行图像物体检测

准备

  1. 安装TensorFlow

    参考 https://www.tensorflow.org/install/

    如在Ubuntu下安装TensorFlow with GPU support, python 2.7版本

    wget https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.2.0-cp27-none-linux_x86_64.whl
    pip install tensorflow_gpu-1.2.0-cp27-none-linux_x86_64.whl
  2. 配置TensorFlow Models

    • 下载TensorFlow Models
    git clone https://github.com/tensorflow/models.git
    • 编译protobuf
    
    # From tensorflow/models/
    
    protoc object_detection/protos/*.proto --python_out=.

    生成若干py文件在object_detection/protos/

    • 添加PYTHONPATH
    
    # From tensorflow/models/
    
    export PYTHONPATH=$PYTHONPATH:`pwd`:`pwd`/slim
    • 测试
    
    # From tensorflow/models/
    
    python object_detection/builders/model_builder_test.py

    若成功,显示OK

  3. 准备数据

    参考 https://github.com/tensorflow/models/blob/master/object_detection/g3doc/preparing_inputs.md

    这里以PASCAL VOC 2012为例。

    • 下载并解压
    
    # From tensorflow/models
    
    wget http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar
    tar -xvf VOCtrainval_11-May-2012.tar
    • 生成TFRecord
    
    # From tensorflow/models
    
    mkdir VOC2012
    python object_detection/create_pascal_tf_record.py \
        --label_map_path=object_detection/data/pascal_label_map.pbtxt \
        --data_dir=VOCdevkit --year=VOC2012 --set=train \
        --output_path=VOC2012/pascal_train.record
    python object_detection/create_pascal_tf_record.py \
        --label_map_path=object_detection/data/pascal_label_map.pbtxt \
        --data_dir=VOCdevkit --year=VOC2012 --set=val \
        --output_path=VOC2012/pascal_val.record

    得到pascal_train.recordpascal_val.record

    如果需要用自己的数据,则参考create_pascal_tf_record.py编写处理数据生成TFRecord的脚本。可参考 https://github.com/tensorflow/models/blob/master/object_detection/g3doc/using_your_own_dataset.md

  4. (可选)下载模型

    官方提供了不少预训练模型( https://github.com/tensorflow/models/blob/master/object_detection/g3doc/detection_model_zoo.md ),这里以ssd_mobilenet_v1_coco以例。

    
    # From tensorflow/models
    
    wget http://download.tensorflow.org/models/object_detection/ssd_mobilenet_v1_coco_11_06_2017.tar.gz
    tar zxf ssd_mobilenet_v1_coco_11_06_2017.tar.gz

训练

如果使用现有模型进行预测则不需要训练。

  1. 文件结构

    为了方便查看文件,使用以下文件结构。

    models
    ├── object_detection
    │   ├── VOC2012
    │   │   ├── ssd_mobilenet_train_logs
    │   │   ├── ssd_mobilenet_val_logs
    │   │   ├── ssd_mobilenet_v1_voc2012.config
    │   │   ├── pascal_label_map.pbtxt
    │   │   ├── pascal_train.record
    │   │   └── pascal_val.record
    │   ├── infer.py
    │   └── create_pascal_tf_record.py
    ├── eval_voc2012.sh
    └── train_voc2012.sh
  2. 配置

    参考 https://github.com/tensorflow/models/blob/master/object_detection/g3doc/configuring_jobs.md

    这里使用SSD w/MobileNet,把object_detection/samples/configs/ssd_mobilenet_v1_pets.config复制到object_detection/VOC2012/ssd_mobilenet_v1_voc2012.config

    修改第9行为num_classes: 20

    修改第158行为fine_tune_checkpoint: "object_detection/ssd_mobilenet_v1_coco_11_06_2017/model.ckpt"

    修改第177行为input_path: "object_detection/VOC2012/pascal_train.record"

    修改第179行和193行为label_map_path: "object_detection/data/pascal_label_map.pbtxt"

    修改第191行为input_path: "object_detection/VOC2012/pascal_val.record"

  3. 训练

    新建tensorflow/models/train_voc2012.sh,内容以下:

    python object_detection/train.py \
        --logtostderr \
        --pipeline_config_path=object_detection/VOC2012/ssd_mobilenet_v1_voc2012.config \
        --train_dir=object_detection/VOC2012/ssd_mobilenet_train_logs \
        2>&1 | tee object_detection/VOC2012/ssd_mobilenet_train_logs.txt &

    进入tensorflow/models/,运行./train_voc2012.sh即可训练。

  4. 验证

    可一边训练一边验证,注意使用其它的GPU或合理分配显存。

    新建tensorflow/models/eval_voc2012.sh,内容以下:

    python object_detection/eval.py \
        --logtostderr \
        --pipeline_config_path=object_detection/VOC2012/ssd_mobilenet_v1_voc2012.config \
        --checkpoint_dir=object_detection/VOC2012/ssd_mobilenet_train_logs \
        --eval_dir=object_detection/VOC2012/ssd_mobilenet_val_logs &

    进入tensorflow/models/,运行CUDA_VISIBLE_DEVICES="1" ./train_voc2012.sh即可验证(这里指定了第二个GPU)。

  5. 可视化log

    可一边训练一边可视化训练的log,可看到Loss趋势。

    tensorboard --logdir ssd_mobilenet_train_logs/

    可视化验证的log,可看到Precision/mAP@0.5IOU的趋势以及具体image的预测结果。

    tensorboard --logdir ssd_mobilenet_val_logs/ --port 6007

测试

  1. 导出模型

    训练完成后得到一些checkpoint文件在ssd_mobilenet_train_logs中,如:

    • graph.pbtxt
    • model.ckpt-200000.data-00000-of-00001
    • model.ckpt-200000.info
    • model.ckpt-200000.meta

    其中meta保存了graph和metadata,ckpt保存了网络的weights。

    而进行预测时只需模型和权重,不需要metadata,故可使用官方提供的脚本生成推导图。

    python object_detection/export_inference_graph.py \
        --input_type image_tensor \
        --pipeline_config_path object_detection/VOC2012/ssd_mobilenet_v1_voc2012.config \
        --trained_checkpoint_prefix object_detection/VOC2012/ssd_mobilenet_train_logs/model.ckpt-200000 \
        --output_directory object_detection/VOC2012
  2. 测试图片

    • 运行object_detection_tutorial.ipynb并修改其中的各种路径即可。

    • 或自写编译inference脚本,如tensorflow/models/object_detection/infer.py

      import sys
      sys.path.append('..')
      import os
      import time
      import tensorflow as tf
      import numpy as np
      from PIL import Image
      from matplotlib import pyplot as plt
      
      from utils import label_map_util
      from utils import visualization_utils as vis_util
      
      PATH_TEST_IMAGE = sys.argv[1]
      PATH_TO_CKPT = 'VOC2012/frozen_inference_graph.pb'
      PATH_TO_LABELS = 'VOC2012/pascal_label_map.pbtxt'
      NUM_CLASSES = 21
      IMAGE_SIZE = (18, 12)
      
      label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
      categories = label_map_util.convert_label_map_to_categories(
          label_map, max_num_classes=NUM_CLASSES, use_display_name=True)
      category_index = label_map_util.create_category_index(categories)
      
      detection_graph = tf.Graph()
      with detection_graph.as_default():
          od_graph_def = tf.GraphDef()
          with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
              serialized_graph = fid.read()
              od_graph_def.ParseFromString(serialized_graph)
              tf.import_graph_def(od_graph_def, name='')
      
      config = tf.ConfigProto()
      config.gpu_options.allow_growth = True
      
      with detection_graph.as_default():
          with tf.Session(graph=detection_graph, config=config) as sess:
              start_time = time.time()
              print(time.ctime())
              image = Image.open(PATH_TEST_IMAGE)
              image_np = np.array(image).astype(np.uint8)
              image_np_expanded = np.expand_dims(image_np, axis=0)
              image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
              boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
              scores = detection_graph.get_tensor_by_name('detection_scores:0')
              classes = detection_graph.get_tensor_by_name('detection_classes:0')
              num_detections = detection_graph.get_tensor_by_name('num_detections:0')
              (boxes, scores, classes, num_detections) = sess.run(
                  [boxes, scores, classes, num_detections],
                  feed_dict={image_tensor: image_np_expanded})
              print('{} elapsed time: {:.3f}s'.format(time.ctime(), time.time() - start_time))
              vis_util.visualize_boxes_and_labels_on_image_array(
                  image_np, np.squeeze(boxes), np.squeeze(classes).astype(np.int32), np.squeeze(scores),
                  category_index, use_normalized_coordinates=True, line_thickness=8)
              plt.figure(figsize=IMAGE_SIZE)
              plt.imshow(image_np)

      运行infer.py test_images/image1.jpg即可

评论 12
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值