Retinanet 较之 SSD, YOLO系one stage目标检测网络,在保证检测速度的基础上,很大的提高了在样本类别分布不平衡情况下的检测精度,这得益于He Kaiming等人所提出的Focal Loss
Paper可见
由于在工作或一些特定场景中,我们的数据很大可能不会像COCO, Pascal VOC这样的实验数据集有很多分布相对均匀的类,而基本是集中在某几种类别上。这时使用Retinanet 就比较合适了。
本文主要介绍如何使用Keras版本的Retinanet训练自己的数据集,代码及参考来自:
我的开发环境 Ubuntu 16.04 + python 3.6
Windows下亦可运行,但需要Microsoft Visual C++ 14.0支持
下载:Microsoft Visual C++ Build Tools
1 准备自己的数据集
根据自己系统下载labelImg工具,对自己的图片数据打上标签。数据集可自行爬取相关类目,或这里推荐一位同事自己制作数据集的博客,有非常详细的目标检测数据准备步骤及数据下载
2 搭建Retinanet环境
代码库下载地址https://github.com/fizyr/keras-retinanet,或git命令:
git clone https://github.com/fizyr/keras-retinanet.git
- 获得代码库后进入keras-retinanet文件夹
- 确认有未安装numpy
pip install numpy --user
- 在这个文件夹内运行下面代码来安装keras-retinanet库,确认你已经根据自己的系统需求安装了tensorflow
pip install . --user
或者你可以选择不安装库,仅在这个文件内运行相关训练测试代码,但是要先运行下述命令来编译Cython代码*(尽量不选择这种方式,仅能在此文件夹内运行相关代码,非常麻烦)*
python setup.py build_ext --inplace
3 生成CSV文件定义自己的数据集
训练自己的数据集需要至少两个CSV文件,一个文件包含标注数据,另一个则包含各个类别名及其对应的ID序号映射。
-
数据标注文件格式
路径,xmin,ymin,xmax,ymax,类别名 path/to/image.jpg,x1,y1,x2,y2,class_name
注:数据标注的CSV文件需一行仅包含一条标注信息,这和yolo标注数据集形成的.txt有所不同,即若一幅图片包含了多个标注的bounding boxes ,一行也仅显示一条标注,共分几行来进行标识,若没有则路径后仅有‘,’,示例如下:
/data/imgs/img_001.jpg,837,346,981,456,cow /data/imgs/img_002.jpg,215,312,279,391,cat /data/imgs/img_002.jpg,22,5,89,84,bird /data/imgs/img_003.jpg,,,,,
-
类别映射文件格式
类别,序号(从0开始) class_name,id 例: cow,0 cat,1 bird,2
写了个代码根据xml生成csv文件,下载:xml2csv.py
"""
运行方式:命令行 python xml2csv.py -i indir(图片及标注的母目录)
注:必须参数: -i 指定包含有图片及标注的母文件夹,图片及标注可不在同一子目录里,但名称必须一一对应
(图片格式默认.jpg,若为其他格式可见代码中注释自行修改)
可选参数: -p 交叉验证集拆分比,默认0.05
-t 生成训练集CSV文件名称,默认train.csv
-v 生成交叉验证集CSV文件名称,默认val.csv
-c 生成类别映射CSV文件名称,默认class.csv
"""
import os
import xml.etree.ElementTree as ET
import random
import math
import argparse
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('-i', '--indir', type=str)
parser.add_argument('-p', '--percent', type=float, default=0.05)
parser.add_argument('-t', '--train', type=str, default='train.csv')
parser.add_argument('-v', '--val', type=str, default='val.csv')
parser.add_argument('-c', '--classes', type=str, default='class.csv')
args = parser.parse_args()
return args
#获取特定后缀名的文件列表
def get_file_index(indir, postfix):
file_list = []
for root, dirs, files in os.walk(indir):
for name in files:
if postfix in name:
file_list.append(os.path.join(root, name))
return file_list
#写入标注信息
def convert_annotation(csv, address_list):
cls_list = []
with open(csv, 'w') as f:
for i, address in enumerate(address_list):
in_file = open(address, encoding='utf8')
strXml =in_file.read(