一. 安装
Tensorflow object detection api是tensorflow官方出品的检测工具包,集成了像ssd、faster rcnn等检测算法,mobilenet、inception、resnet等backbone和fpn、ppn等方法,各模块之间能够通过组合的方式来work。
Github下载地址:https://github.com/tensorflow/models
解压models-master,主要的内容都在/research/目录下,里面有很多代码包括ocr、nlp、speech等,本节我们只需要关注object_detection,也就是目标检测部分。
代码安装步骤如下:
1)确保电脑上已安装了tensorflow环境+keras,建议tf版本>=1.6,注意tf版本要与CUDA、cudnn环境匹配,否则可能会有意想不到的错误。ps:tensorflow版本升级过快,前向兼容性不好(接口不一致),确实得吐槽。
2)编译protoc (建议版本3.4+)
cd research
protoc --python_out=. object_detection/protos/*.proto
3)安装research & slim
python setup.py install
cd slim
python setup.py install
4)添加系统路径(research目录)
export PYTHONPATH=$PYTHONPATH:`pwd`:`pwd`/slim
也可以添加到系统变量:
vim ~/.bashrc
export PYTHONPATH=$PYTHONPATH:/path/to/research:/path/to/research/slim
source ~/.bashrc
5)测试是否安装成功(research目录)
python object_detection/builders/model_builder_test.py
安装成功会显示ok。
二. 数据准备
数据标注工具建议使用labelImg,采用xml进行数据保存,相关教程比较多,这里不再赘述。
这里需要说明的是:tensorflow需要tfrecord格式数据,需要在完成的标注数据基础上进行数据转换,这里将标注数据分为两个文件夹,train和test,文件夹下包含图片文件和xml。
需要通过脚本进行数据转换:首先将annotation转换成csv,然后将csv转换成tfrecord。
python xml_to_csv.py /path/to/train ./data/train_labels.csv
python xml_to_csv.py /path/to/test ./data/test_labels.csv
python generate_tfrecord.py --csv_input=data/train_labels.csv --output_path=train.record
python generate_tfrecord.py --csv_input=data/test_labels.csv --output_path=test.record
# xml_to_csv.py
# -*- coding: utf-8 -*-
import os, sys
import glob
import pandas as pd
import xml.etree.ElementTree as ET
def xml_to_csv(_path, _out_file):
xml_list = []
for xml_file in glob.glob(_path + '/*.xml'):
tree = ET.parse(xml_file)
root = tree.getroot()
for member in root.findall('object'):
value = (root.find('filename').text,
int(root.find('size')[0].text),
int(root.find('size')[1].text),
member[0].text,
int(member[4][0].text),