Reference: https://github.com/tensorflow/models/tree/master/object_detection
Object detection in images with the TensorFlow Object Detection API
Preparation
Install TensorFlow
See https://www.tensorflow.org/install/
For example, to install TensorFlow 1.2.0 with GPU support for Python 2.7 on Ubuntu:
    wget https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.2.0-cp27-none-linux_x86_64.whl
    pip install tensorflow_gpu-1.2.0-cp27-none-linux_x86_64.whl
Configure TensorFlow Models
- Download TensorFlow Models
    git clone https://github.com/tensorflow/models.git
- Compile the protobuf definitions
    # From tensorflow/models/
    protoc object_detection/protos/*.proto --python_out=.
  This generates several .py files in object_detection/protos/.
- Add the libraries to PYTHONPATH
    # From tensorflow/models/
    export PYTHONPATH=$PYTHONPATH:`pwd`:`pwd`/slim
- Test the installation
    # From tensorflow/models/
    python object_detection/builders/model_builder_test.py
  If the setup is correct, the test output ends with OK.
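If the test fails with an import error instead, the usual suspect is PYTHONPATH. The following is a minimal sketch of a check for the two entries that the export command above adds; the helper name `missing_pythonpath_entries` is made up for this illustration and is not part of the Object Detection API.

```python
import os


def missing_pythonpath_entries(models_dir, pythonpath=None):
    """Return the required directories that are absent from PYTHONPATH."""
    if pythonpath is None:
        pythonpath = os.environ.get("PYTHONPATH", "")
    # Normalise every entry so that trailing slashes and relative paths compare equal.
    present = {os.path.abspath(p) for p in pythonpath.split(os.pathsep) if p}
    required = [
        os.path.abspath(models_dir),                        # `pwd`
        os.path.abspath(os.path.join(models_dir, "slim")),  # `pwd`/slim
    ]
    return [path for path in required if path not in present]


if __name__ == "__main__":
    # Run from tensorflow/models/, mirroring the export command.
    for path in missing_pythonpath_entries(os.getcwd()):
        print("missing from PYTHONPATH: {}".format(path))
```

An empty result means both directories are visible to Python; anything printed needs to be exported again in the current shell.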
Prepare the data
See https://github.com/tensorflow/models/blob/master/object_detection/g3doc/preparing_inputs.md
PASCAL VOC 2012 is used as the example here.
- Download and extract the dataset
    # From tensorflow/models
    wget http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar
    tar -xvf VOCtrainval_11-May-2012.tar
- Generate the TFRecord files
    # From tensorflow/models
    mkdir VOC2012
    python object_detection/create_pascal_tf_record.py \
        --label_map_path=object_detection/data/pascal_label_map.pbtxt \
        --data_dir=VOCdevkit --year=VOC2012 --set=train \
        --output_path=VOC2012/pascal_train.record
    python object_detection/create_pascal_tf_record.py \
        --label_map_path=object_detection/data/pascal_label_map.pbtxt \
        --data_dir=VOCdevkit --year=VOC2012 --set=val \
        --output_path=VOC2012/pascal_val.record
  This produces pascal_train.record and pascal_val.record in VOC2012/.
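To make the conversion less opaque, here is a rough sketch of the per-annotation parsing that a script such as create_pascal_tf_record.py has to perform: read one PASCAL VOC XML file and normalise its pixel boxes by the image size. It uses only the standard library; the function name and the sample annotation are invented for illustration, and the real script additionally encodes the image bytes, maps class names to ids through the label map, and writes the result as TFRecord entries.

```python
import xml.etree.ElementTree as ET


def parse_voc_annotation(xml_text):
    """Return (filename, boxes); each box is (label, ymin, xmin, ymax, xmax) in [0, 1]."""
    root = ET.fromstring(xml_text)
    width = float(root.find("size/width").text)
    height = float(root.find("size/height").text)
    boxes = []
    for obj in root.findall("object"):
        box = obj.find("bndbox")
        boxes.append((
            obj.find("name").text,
            float(box.find("ymin").text) / height,
            float(box.find("xmin").text) / width,
            float(box.find("ymax").text) / height,
            float(box.find("xmax").text) / width,
        ))
    return root.find("filename").text, boxes


# Hypothetical annotation, shaped like the files in VOCdevkit/VOC2012/Annotations/.
SAMPLE = """
<annotation>
  <filename>example.jpg</filename>
  <size><width>400</width><height>200</height><depth>3</depth></size>
  <object>
    <name>dog</name>
    <bndbox><xmin>100</xmin><ymin>50</ymin><xmax>300</xmax><ymax>150</ymax></bndbox>
  </object>
</annotation>
"""

if __name__ == "__main__":
    print(parse_voc_annotation(SAMPLE))
```

A dataset in another annotation format only needs a different parser that yields the same kind of normalised boxes.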
To use your own data, write a conversion script modeled on create_pascal_tf_record.py that turns your data into TFRecord files. See https://github.com/tensorflow/models/blob/master/object_detection/g3doc/using_your_own_dataset.md
(Optional) Download a pretrained model
The project provides a number of pretrained models ( https://github.com/tensorflow/models/blob/master/object_detection/g3doc/detection_model_zoo.md ); ssd_mobilenet_v1_coco is used as the example here.
    # From tensorflow/models
    wget http://download.tensorflow.org/models/object_detection/ssd_mobilenet_v1_coco_11_06_2017.tar.gz
    tar zxf ssd_mobilenet_v1_coco_11_06_2017.tar.gz
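Training
Training is not needed if you only want to run prediction with an existing model.
File layout
To keep the files easy to find, the rest of this guide uses the layout below. The earlier commands wrote their output directly under tensorflow/models/, so move the generated .record files, a copy of the label map, and the extracted pretrained model (which the config expects under object_detection/) into place first.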
models
├── object_detection
│   ├── VOC2012
│   │   ├── ssd_mobilenet_train_logs
│   │   ├── ssd_mobilenet_val_logs
│   │   ├── ssd_mobilenet_v1_voc2012.config
│   │   ├── pascal_label_map.pbtxt
│   │   ├── pascal_train.record
│   │   └── pascal_val.record
│   ├── infer.py
│   └── create_pascal_tf_record.py
├── eval_voc2012.sh
└── train_voc2012.sh
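The data and model steps were run from tensorflow/models/, so their outputs do not yet sit where this layout puts them. A possible way to arrange them is sketched below, on the assumption that those commands were run exactly as written; `arrange_layout` is a hypothetical helper, and every move is skipped when its source is absent.

```python
import os
import shutil


def arrange_layout(models_dir):
    """Move the outputs of the earlier steps under object_detection/."""
    od_dir = os.path.join(models_dir, "object_detection")
    voc_dir = os.path.join(od_dir, "VOC2012")
    for log_dir in ("ssd_mobilenet_train_logs", "ssd_mobilenet_val_logs"):
        path = os.path.join(voc_dir, log_dir)
        if not os.path.isdir(path):
            os.makedirs(path)

    # The TFRecord files were written to models/VOC2012/ by the commands above.
    for name in ("pascal_train.record", "pascal_val.record"):
        src = os.path.join(models_dir, "VOC2012", name)
        if os.path.exists(src):
            shutil.move(src, os.path.join(voc_dir, name))

    # infer.py reads the label map from VOC2012/, so keep a copy there.
    label_map = os.path.join(od_dir, "data", "pascal_label_map.pbtxt")
    if os.path.exists(label_map):
        shutil.copy(label_map, voc_dir)

    # The training config expects the pretrained checkpoint under object_detection/.
    model = "ssd_mobilenet_v1_coco_11_06_2017"
    src = os.path.join(models_dir, model)
    if os.path.isdir(src) and not os.path.exists(os.path.join(od_dir, model)):
        shutil.move(src, os.path.join(od_dir, model))
    return voc_dir
```

Calling `arrange_layout(".")` from tensorflow/models/ would produce the tree shown above, apart from the config and scripts that are created in the following steps.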
Configuration
See https://github.com/tensorflow/models/blob/master/object_detection/g3doc/configuring_jobs.md
This guide uses SSD with MobileNet. Copy object_detection/samples/configs/ssd_mobilenet_v1_pets.config to object_detection/VOC2012/ssd_mobilenet_v1_voc2012.config, then edit the copy as follows.
Change line 9 to
    num_classes: 20
Change line 158 to
    fine_tune_checkpoint: "object_detection/ssd_mobilenet_v1_coco_11_06_2017/model.ckpt"
Change line 177 to
    input_path: "object_detection/VOC2012/pascal_train.record"
Change lines 179 and 193 to
    label_map_path: "object_detection/data/pascal_label_map.pbtxt"
Change line 191 to
    input_path: "object_detection/VOC2012/pascal_val.record"
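Line numbers shift whenever the sample config changes upstream, so it can be safer to edit by field name. The sketch below applies the same five edits with regular expressions. It assumes each field sits on its own line and that the training reader precedes the evaluation reader in the file, as the line numbers above suggest; `set_field` and `patch_config` are illustrative names, not part of the API.

```python
import re


def set_field(text, key, values):
    """Replace the value of every `key: ...` line, in file order, with `values`."""
    pattern = re.compile(r"^([ \t]*{}:[ \t]*).*$".format(re.escape(key)), re.MULTILINE)
    found = len(pattern.findall(text))
    if found != len(values):
        raise ValueError("expected {} occurrence(s) of {!r}, found {}".format(
            len(values), key, found))
    pending = iter(values)
    # Keep the indentation and key, swap only the value.
    return pattern.sub(lambda match: match.group(1) + next(pending), text)


def patch_config(text):
    """Apply the VOC2012 edits described above to the text of the pets config."""
    text = set_field(text, "num_classes", ["20"])
    text = set_field(text, "fine_tune_checkpoint", [
        '"object_detection/ssd_mobilenet_v1_coco_11_06_2017/model.ckpt"'])
    text = set_field(text, "input_path", [
        '"object_detection/VOC2012/pascal_train.record"',  # train reader comes first
        '"object_detection/VOC2012/pascal_val.record"',    # eval reader comes second
    ])
    text = set_field(text, "label_map_path", [
        '"object_detection/data/pascal_label_map.pbtxt"'] * 2)
    return text
```

Reading the copied config, passing its text through `patch_config`, and writing it back gives the same result as the manual edits; a `ValueError` signals that the file no longer has the expected fields.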
Training
Create tensorflow/models/train_voc2012.sh with the following content:
    python object_detection/train.py \
        --logtostderr \
        --pipeline_config_path=object_detection/VOC2012/ssd_mobilenet_v1_voc2012.config \
        --train_dir=object_detection/VOC2012/ssd_mobilenet_train_logs \
        2>&1 | tee object_detection/VOC2012/ssd_mobilenet_train_logs.txt &
From tensorflow/models/, run ./train_voc2012.sh to start training (the script must be executable, e.g. chmod +x train_voc2012.sh).
Evaluation
Evaluation can run alongside training; use a different GPU, or budget GPU memory so that both jobs fit.
Create tensorflow/models/eval_voc2012.sh with the following content:
    python object_detection/eval.py \
        --logtostderr \
        --pipeline_config_path=object_detection/VOC2012/ssd_mobilenet_v1_voc2012.config \
        --checkpoint_dir=object_detection/VOC2012/ssd_mobilenet_train_logs \
        --eval_dir=object_detection/VOC2012/ssd_mobilenet_val_logs &
From tensorflow/models/, run CUDA_VISIBLE_DEVICES="1" ./eval_voc2012.sh to start evaluation (this pins the job to the second GPU).
Visualize the logs
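The training logs can be visualized while training is running, which shows the loss trend. Run the commands from object_detection/VOC2012/ so that the relative log directories resolve.
    tensorboard --logdir ssd_mobilenet_train_logs/
Visualizing the evaluation logs shows the trend of Precision/mAP@0.5IOU as well as the predictions on individual images.
    tensorboard --logdir ssd_mobilenet_val_logs/ --port 6007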
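Testing
Export the model
When training finishes, ssd_mobilenet_train_logs contains checkpoint files such as:
- graph.pbtxt
- model.ckpt-200000.data-00000-of-00001
- model.ckpt-200000.index
- model.ckpt-200000.meta
The .meta file stores the graph and metadata, and the .data file stores the network weights.
Prediction needs only the graph and the weights, not the training metadata, so the official script can be used to export an inference graph.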
    # From tensorflow/models/
    python object_detection/export_inference_graph.py \
        --input_type image_tensor \
        --pipeline_config_path object_detection/VOC2012/ssd_mobilenet_v1_voc2012.config \
        --trained_checkpoint_prefix object_detection/VOC2012/ssd_mobilenet_train_logs/model.ckpt-200000 \
        --output_directory object_detection/VOC2012
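The step number in --trained_checkpoint_prefix (200000 above) depends on how long training ran. As a small convenience, the sketch below picks the newest prefix by scanning the training directory, assuming checkpoints follow TensorFlow's model.ckpt-N.index naming; the helper name is invented here. TensorFlow's own `tf.train.latest_checkpoint` serves a similar purpose when TensorFlow is importable.

```python
import os
import re


def latest_checkpoint_prefix(train_dir):
    """Return the model.ckpt-N prefix with the largest step N, or None if absent."""
    steps = []
    for name in os.listdir(train_dir):
        match = re.match(r"model\.ckpt-(\d+)\.index$", name)
        if match:
            # Compare steps numerically: as strings, "9000" would sort after "200000".
            steps.append(int(match.group(1)))
    if not steps:
        return None
    return os.path.join(train_dir, "model.ckpt-{}".format(max(steps)))
```

The returned path can be passed as the value of --trained_checkpoint_prefix.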
Test on an image
Run object_detection_tutorial.ipynb after adjusting the paths in it. Alternatively, write your own inference script, such as tensorflow/models/object_detection/infer.py:
    import sys
    sys.path.append('..')  # make tensorflow/models/ importable when run from object_detection/

    import time

    import numpy as np
    import tensorflow as tf
    from PIL import Image
    from matplotlib import pyplot as plt

    from utils import label_map_util
    from utils import visualization_utils as vis_util

    PATH_TEST_IMAGE = sys.argv[1]
    PATH_TO_CKPT = 'VOC2012/frozen_inference_graph.pb'
    PATH_TO_LABELS = 'VOC2012/pascal_label_map.pbtxt'
    NUM_CLASSES = 21  # upper bound on label ids; PASCAL VOC itself has 20 classes
    IMAGE_SIZE = (18, 12)  # figure size in inches

    # Map class ids to display names.
    label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
    categories = label_map_util.convert_label_map_to_categories(
        label_map, max_num_classes=NUM_CLASSES, use_display_name=True)
    category_index = label_map_util.create_category_index(categories)

    # Load the frozen inference graph.
    detection_graph = tf.Graph()
    with detection_graph.as_default():
        od_graph_def = tf.GraphDef()
        with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
            serialized_graph = fid.read()
            od_graph_def.ParseFromString(serialized_graph)
            tf.import_graph_def(od_graph_def, name='')

    # Allocate GPU memory on demand instead of reserving all of it.
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    with detection_graph.as_default():
        with tf.Session(graph=detection_graph, config=config) as sess:
            start_time = time.time()
            print(time.ctime())
            # Force three channels so that grayscale or RGBA files also fit the input tensor.
            image = Image.open(PATH_TEST_IMAGE).convert('RGB')
            image_np = np.array(image).astype(np.uint8)
            image_np_expanded = np.expand_dims(image_np, axis=0)  # add the batch dimension
            image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
            boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
            scores = detection_graph.get_tensor_by_name('detection_scores:0')
            classes = detection_graph.get_tensor_by_name('detection_classes:0')
            num_detections = detection_graph.get_tensor_by_name('num_detections:0')
            (boxes, scores, classes, num_detections) = sess.run(
                [boxes, scores, classes, num_detections],
                feed_dict={image_tensor: image_np_expanded})
            print('{} elapsed time: {:.3f}s'.format(time.ctime(), time.time() - start_time))
            # Draw the boxes and labels onto image_np in place.
            vis_util.visualize_boxes_and_labels_on_image_array(
                image_np,
                np.squeeze(boxes),
                np.squeeze(classes).astype(np.int32),
                np.squeeze(scores),
                category_index,
                use_normalized_coordinates=True,
                line_thickness=8)
            plt.figure(figsize=IMAGE_SIZE)
            plt.imshow(image_np)
            plt.show()  # without this, nothing is displayed when run as a script
From tensorflow/models/object_detection/, run python infer.py test_images/image1.jpg to see the detections.
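When the detections are needed as data rather than as a drawn image, the arrays returned by `sess.run` can be post-processed directly. The following is a minimal sketch, assuming the boxes are `[ymin, xmin, ymax, xmax]` rows normalised to `[0, 1]` (which is what `use_normalized_coordinates=True` in infer.py implies); the function name and the 0.5 threshold are arbitrary choices for illustration.

```python
def select_detections(boxes, scores, classes, image_width, image_height, min_score=0.5):
    """Keep detections scoring at least min_score and convert their boxes to pixels."""
    results = []
    for box, score, class_id in zip(boxes, scores, classes):
        if score < min_score:
            continue
        ymin, xmin, ymax, xmax = box
        results.append({
            "class_id": int(class_id),
            "score": float(score),
            # (left, top, right, bottom) in pixel coordinates
            "box_pixels": (
                int(round(xmin * image_width)),
                int(round(ymin * image_height)),
                int(round(xmax * image_width)),
                int(round(ymax * image_height)),
            ),
        })
    return results
```

Inside infer.py it could be called as `select_detections(np.squeeze(boxes), np.squeeze(scores), np.squeeze(classes), image_np.shape[1], image_np.shape[0])`, and each `class_id` can be looked up in `category_index` to get the class name.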