声明:部分代码参考https://blog.csdn.net/Jesse_Mx/article/details/65634482
本文侧重于通过处理数据集来锻炼python的应用能力
本文所有代码已进行开源,读者可以前往:https://download.csdn.net/download/tanghong1996/10596440进行下载
下载数据集
博主打算将SSD算法用于检测车载视频,用到的是CITY数据集(自建数据集)。
读者可以采用KITTI数据集,内容相似,本文主要针对数据进行前期处理,关于该数据集的说明本文就不进行介绍了。
进入官网,找到object一栏,准备下载数据集:
根据下载情况(博主把前四个都下载了,点开看过),进行SSD训练只需要下载第1个图片集 Download left color images of object data set (12 GB)和标注文件 Download training labels of object data set (5 MB) 就够了。然后将其解压,发现其中7481张训练图片有标注信息,而测试图片没有,这就是本次训练所使用的图片数量。由于SSD中训练脚本是基于VOC数据集格式的,所以我们需要把KITTI数据集做成PASCAL VOC的格式,其基本架构可以参看这篇博客:PASCAL VOC数据集分析 。根据SSD训练要求,博主在/home/th/data/中目录中建立一系列文件夹存放所需数据集和工具文件,具体如下:
PS.参看截图,数据要放在home目录下的data文件夹,不是caffe中的data文件夹,这个要注意,否则后续脚本出错。
(截图来源于小规模试验,图片只有400张,本人实际测试了2万张)
转换数据集
为了方便SSD进行训练,我们需要将KITTI数据集转换成PASCAL VOC的格式,细心的朋友可能已经发现,KITTI官网提供了一个工具: code to convert from KITTI to PASCAL VOC file format ,为啥不用呢?因为我觉得很难用,缺乏灵活性,还不如自己的Python转换工具好使。
转换KITTI类别
KITTI数据集总共20个类别,如果用于特定场景,20个类别确实多了。此次博主为数据集设置1个类别 ‘Car’,只不过标注信息中还有其他类型的车和人,直接略过有点浪费,博主希望将 ‘Van’, ‘Truck’, ‘Tram’ 合并到 ‘Car’ 类别中去,将 ‘Person_sitting’,’Cyclist’,’Pedestrian’ 合并到 ‘Pedestrian’ 类别中去,并删除Pedestrian类。这里使用的是modify_annotations_txt.py工具,源码如下:
modify_annotations_txt.py
import glob
import string
txt_list = glob.glob('./Labels/*.txt') # 存储Labels文件夹所有txt文件路径
def show_category(txt_list):
category_list= []
for item in txt_list:
try:
with open(item) as tdf:
for each_line in tdf:
labeldata = each_line.strip().split(' ') # 去掉前后多余的字符并把其分开
category_list.append(labeldata[0]) # 只要第一个字段,即类别
except IOError as ioerr:
print('File error:'+str(ioerr))
print(set(category_list)) # 输出集合
def merge(line):
each_line=''
for i in range(len(line)):
if i!= (len(line)-1):
each_line=each_line+line[i]+' '
else:
each_line=each_line+line[i] # 最后一条字段后面不加空格
each_line=each_line+'\n'
return (each_line)
print('before modify categories are:\n')
show_category(txt_list)
for item in txt_list:
new_txt=[]
try:
with open(item, 'r') as r_tdf:
for each_line in r_tdf:
labeldata = each_line.strip().split(' ')
if labeldata[0] in ['Truck','Van','Tram']: # 合并汽车类
labeldata[0] = labeldata[0].replace(labeldata[0],'Car')
if labeldata[0] in [</