之前,我们介绍了Fatser R-CNN模型,在接下来的几篇文章,将通过Keras框架来完整实现Fatser R-CNN模型。数据集我们采用经典的VOC数据集。
这篇文章我们主要看下相关数据的准备工作,具体流程如下:
一、VOC数据集解析
VOC数据集的下载,,因为官网下载太慢,文章末尾处有提供百度网盘下载
下载解压后的文件目录如下:
对于目标检测任务,只需要用到Annotations,ImageSets,JPEGImages这三个目录。
1. Annotations:存放相关标注信息,每一张图片对应一个xml文件,具体xml内容如下:
<annotation>
<folder>VOC2012</folder>
<filename>2007_000033.jpg</filename>
<source>
<database>The VOC2007 Database</database>
<annotation>PASCAL VOC2007</annotation>
<image>flickr</image>
</source>
<size>
<width>500</width>
<height>366</height>
<depth>3</depth>
</size>
<segmented>1</segmented>
<object>
<name>aeroplane</name>
<pose>Unspecified</pose>
<truncated>0</truncated>
<difficult>0</difficult>
<bndbox>
<xmin>9</xmin>
<ymin>107</ymin>
<xmax>499</xmax>
<ymax>263</ymax>
</bndbox>
</object>
<object>
<name>aeroplane</name>
<pose>Left</pose>
<truncated>0</truncated>
<difficult>0</difficult>
<bndbox>
<xmin>421</xmin>
<ymin>200</ymin>
<xmax>482</xmax>
<ymax>226</ymax>
</bndbox>
</object>
</annotation>
2. ImageSets:我们只会用到ImageSets\Main下train.txt , val.txt, test.txt这三个文件,里面存储对应训练集,验证集,测试集的图片名称,文件格式如下:
3. JPEGImages:存储所有的图片数据
我们需要将下载来的VOC数据集解析成如下格式
具体代码实现如下
import os
import xml.etree.ElementTree as ET
from tqdm import tqdm
import pprint
def get_data(input_path):
'''
:param input_path: voc数据目录
:return:
image_data:解析后的数据集 list列表
classes_count:一个字典数据结构,key为对应类别名称,value对应为类别所对应的样本(标注框)个数
classes_mapping:一个字典数据结构,key为对应类别名称,value为对应类别的一个标识index
'''
image_data = []
classes_count = {} #一个字典,key为对应类别名称,value对应为类别所对应的样本(标注框)个数
classes_mapping = {} #一个字典数据结构,key为对应类别名称,value为对应类别的一个标识index
data_paths = os.path.join(input_path, "VOC2012")
print(data_paths)
annota_path = os.path.join(data_paths, "Annotations") # 数据标注目录
imgs_path = os.path.join(data_paths, "JPEGImages") # 图片目录
imgsets_path_train = os.path.join(data_paths, 'ImageSets', 'Main', 'train.txt')
imgsets_path_val = os.path.join(data_paths, 'ImageSets', 'Main', 'val.txt')
imgsets_path_test = os.path.join(data_paths, 'ImageSets', 'Main', 'test.txt')
train_files = [] # 训练集图片名称集合
val_files = [] # 验证集图片名称集合
test_files = [] # 测试集图片名称集合
with open(imgsets_path_train) as f:
for line in f:
# strip() 默认去掉字符串头尾的空格和换行符
train_files.append(line.strip() + '.jpg')
with open(imgsets_path_val) as f:
for line in f:
val_files.append(line.strip() + '.jpg')
# test-set not included in pascal VOC 2012
if os.path.isfile(imgsets_path_test):
with open(imgsets_path_test) as f:
for line in f:
test_files.append(line.strip() + '.jpg')
# 获得所有的标注文件路径,保存到annota_path_list列表中
annota_path_list = [os.path.join(annota_path, s) for s in os.listdir(annota_path)]
index = 0
# Tqdm 是一个快速,可扩展的Python进度条,
# 可以在 Python 长循环中添加一个进度提示信息,用户只需要封装任意的迭代器 tqdm(iterator)
annota_path_list = tqdm(annota_path_list)
for annota_path in annota_path_list:
exist_flag = False
index += 1
annota_path_list.set_description("Processing %s" % annota_path.split(os.sep)[-1])
# 开始解析对应xml数据标注文件
et = ET.parse(annota_path)
element = et.getroot()
element_objs = element.findall("object") # 获取所有的object子元素
element_filename = element.find("filename").text # 对应图片名称
element_width = int(element.find("size").find("width").text) # 对应图片尺寸
element_height = int(element.find("size").find("height").text) # 对应图片尺寸
if (len(element_objs) > 0):
annotation_data = {"filepath": os.path.join(imgs_path, element_filename),
"width": element_width,
"height": element_height,
"image_id": index,
"bboxes": []} # bboxes 用来存放对应标注框的相关位置
if element_filename in train_files:
annotation_data["imageset"] = "train"
exist_flag = True
if element_filename in val_files:
annotation_data["imageset"] = "val"
exist_flag = True
if len(test_files) > 0:
if element_filename in test_files:
annotation_data["imageset"] = "test"
exist_flag = True
if not exist_flag:
continue
for element_obj in element_objs: # 遍历一个xml标注文件中的所有标注框