前言
由于收集数据集并标准过于麻烦,所以采用现有数据集进行训练。
使用Visual Object Classes Challenge 2012 (VOC2012)中的部分类别作为项目的数据集。
共包含20类目标,总计17125张图片。
从17125张图片中提取含有以下三类(person
、car
、bus
)的图片和标签,并将其转移到新的目录。
将标签中的car
、bus
修改为vehicle
。
代码
import xml.etree.ElementTree as ET
import shutil
import os
original_path = "E:\\dataset\\VOCtrainval_11-May-2012\\VOCdevkit\\VOC2012"
target_path = "E:\\Project\\keras-yolo3-person&vehicle\\dataset"
old_classes = ["person", "bus", "car"] # 需要的类别
# 寻找包含"person", "bus", "car"三个类别的文件,并将xml、jpg复制到指定文件夹
def search_file():
print("step1: search file.")
for file_name in os.listdir(os.path.join(original_path, "Annotations")):
old_ann_path = os.path.join(original_path, "Annotations", file_name)
old_img_path = os.path.join(original_path, "JPEGImages", file_name.split('.')[0] + '.jpg')
new_ann_path = os.path.join(target_path, "Annotations", file_name)
new_img_path = os.path.join(target_path, "JPEGImages", file_name.split('.')[0] + '.jpg')
print(old_ann_path)
# 打开xml文件进行解析
in_file = open(old_ann_path)
tree = ET.parse(in_file) # ET是一个xml文件解析库,ET.parse()打开xml文件。parse--"解析"
root = tree.getroot() # 获取根节点
for obj in root.findall('object'): # 找到根节点下所有“object”节点
name = str(obj.find('name').text) # 找到object节点下name子节点的值,不考虑part下的name。
if name in old_classes:
# 将符合的文件(xml、jpg)复制到指定文件夹
shutil.copyfile(old_ann_path, new_ann_path)
shutil.copyfile(old_img_path, new_img_path)
break
# 找到文件中的"person", "bus", "car",并删除其他类别和part标签。
def search_person_vehicle():
print("step2: filter classes.")
for file_name in os.listdir(os.path.join(target_path, "Annotations")):
file_path = os.path.join(target_path, "Annotations", file_name)
print(file_path)
in_file = open(file_path)
tree = ET.parse(in_file) # ET是一个xml文件解析库,ET.parse()打开xml文件。parse--"解析"
root = tree.getroot() # 获取根节点
for obj in root.findall('object'): # 找到根节点下所有“object”节点
name = str(obj.find('name').text) # 找到object节点下name子节点的值,不考虑part下的name。
# 判断:如果不是列出的,(这里可以用in对保留列表成员进行审查),则移除该object节点及其所有子节点。
if not (name in old_classes):
root.remove(obj)
# 移除person目标上的其他标签(hand、foot等)
for pa in obj.findall('part'):
obj.remove(pa)
# 将name为car、bus的节点,改为vehicle
if name in old_classes[1::]:
name = obj.find('name')
name.text = "vehicle"
tree.write(file_path)
if __name__ == '__main__':
search_file()
search_person_vehicle()