使用tensorflow object detection进行训练检测。参考原始代码:https://github.com/tensorflow/models/tree/master/research
本博客以mobilenet-ssd-v2为例进行处理,通过换模型即可实现faster RCNN等的训练检测。
1、数据整理
对生成的数据集(整理成VOC格式),通过Annotations的数据数进行train、test、val、trainval.txt的生成
进入目录
cd VOCdevkit/VOC2012/
python data_segment.py
"""
data_segment.py
可自主设计数据集的比例,即trainval_percent,train_percent
"""
import os
import random
trainval_percent = 0.9
train_percent = 0.95
xmlfilepath = 'Annotations'
txtsavepath = 'ImageSets\Main'
total_xml = os.listdir(xmlfilepath)
num=len(total_xml)
list1=range(num)
tv=int(num*trainval_percent)
tr=int(tv*train_percent)
trainval= random.sample(list1,tv)
train=random.sample(trainval,tr)
ftrainval = open('ImageSets/Main/trainval.txt', 'w')
ftest = open('ImageSets/Main/test.txt', 'w')
ftrain = open('ImageSets/Main/train.txt', 'w')
fval = open('ImageSets/Main/val.txt', 'w')
for i in list1:
name=total_xml[i][:-4]+'\n'
if i in trainval:
ftrainval.write(name)
if i in train:
ftrain.write(name)
else:
fval.write(name)
else:
ftest.write(name)
ftrainval.close()
ftrain.close()
fval.close()
ftest .close()
note:
在使用数据集格式转化工具,生成的voc文件中,ImageSets/mains中,只含有trainval.txt。(可通过上述方式进行重新生成.txt文件,亦可暂时忽略,在后续生成record文件时,直接应用)
数据集的下载可参考自动驾驶数据集,同时能获得自动驾驶数据集与voc格式之间的转换。
2、安装依赖项
pip install -r requirements.txt
3、载入环境变量
protoc object_detection/protos/*.proto --python_out=.
export PYTHONPATH=$PYTHONPATH:`pwd`:`pwd`/object_detection
export PYTHONPATH=$PYTHONPATH:`pwd`:`pwd`/slim
4、生成record文件
# 生成相应的train.record、val.record、test.record
python object_detection/dataset_tools/create_pascal_tf_record.py --data_dir=VOCdevkit/ --year=VOC2012 --set=val --label_map_path=object_detection/data/pascal_label_map.pbtxt --output_path=dataset/val.record
也可根据修改过的record生成文件,直接对数据集转化结果trainval.txt进行处理。
python object_detection/create_crj.py --data_dir=VOCdevkit/ --set=trainval --year=VOC2012 --label_map_path=object_detection/data/pascal_label_map.pbtxt
"""
前面部分,篇幅问题,省略,参考create_pascal_tf_record.py中函数
"""
def main(_):
if FLAGS.set not in SETS:
raise ValueError('set must be in : {}'.format(SETS))
if FLAGS.year not in YEARS:
raise ValueError(