首先,从数据集中筛选出包含所需类别的图片;这里用 base 列表表示所需的类别。
import os
import shutil
import random
import xml.dom.minidom
# --- Dataset locations ----------------------------------------------------
xml_path = 'F:/data/Annotations/'
main_path = 'F:/data/main/'

# --- Class labels ---------------------------------------------------------
# `base` holds the wanted classes; `novel` holds the remaining labelled ones.
novel = ['1', '6', '11']
base = ['2', '3', '4', '5', '10', '9', '8']
gt_dict = dict()

# --- Annotation files, visited in random order ----------------------------
total_xml = os.listdir(xml_path)
random.shuffle(total_xml)

# --- Output image-id lists (each file is created/truncated right here) ----
# NOTE(review): only ftrainval_base is written to by the script below;
# the other two handles just create empty files.
ftrainval_base, fval_base, ftrain_base = (
    open(main_path + list_name, 'w')
    for list_name in ('trainval_base.txt', 'val_base.txt', 'train_base.txt')
)
if __name__ == '__main__':
    # Write the id (file name without extension) of every annotation that
    # contains at least one object of a wanted (`base`) class to
    # trainval_base.txt, then print how many images were selected.
    wanted = set(base)  # O(1) membership tests inside the loop
    cnt = 0  # number of selected images (the original never incremented it)
    try:
        for xm in total_xml:
            image_id, ext = os.path.splitext(xm)
            if ext.lower() != '.xml':
                # Stray non-annotation files in the folder would make the
                # XML parser raise; skip them.
                continue
            dom = xml.dom.minidom.parse(os.path.join(xml_path, xm))
            root = dom.documentElement
            # True as soon as one annotated object belongs to a wanted class.
            has_base = any(
                obj.getElementsByTagName("name")[0].childNodes[0].data in wanted
                for obj in root.getElementsByTagName("object")
            )
            if has_base:
                ftrainval_base.write(image_id + '\n')
                cnt += 1
    finally:
        # Close every list file opened at module level, even when parsing
        # fails half-way through.
        for handle in (ftrainval_base, fval_base, ftrain_base):
            handle.close()
    print(cnt)
选出的图片中可能还含有其他不需要的类别目标,这时需要把图片中这些目标所在的区域用白色遮盖掉,并删除 xml 中对应的标注。
import os
import random
import xml.dom.minidom as xmldom
from tqdm import tqdm
import xml_parse
import cv2
# Purpose: delete the annotations of *unwanted classes* from the xml files
# and fill the matching regions of the images with white.

# Input locations
main_path = 'F:/data/main/'
base_name = 'trainval_base.txt'
xml_path = 'F:/data/Annotations/'
img_path = 'F:/data/JPEGImages/'

# Output locations for the processed images / annotations
erase_img_trainval = 'F:/data/erase_img_trainval/'
erase_xml_trainval = 'F:/data/erase_xml_trainval/'

# Class labels
novel = ['1', '6', '11']
base = ['2', '3', '4', '5', '10', '9', '8']

# Create the output folders when they do not exist yet.
for out_dir in (erase_xml_trainval, erase_img_trainval):
    if os.path.exists(out_dir):
        continue
    os.mkdir(out_dir)
del out_dir
def erase_xml_modify(head, objectlist, keep_classes=None):
    """Remove the annotations of unwanted classes from ``objectlist`` in place.

    Args:
        head: Header data of the annotation; returned untouched.
        objectlist: Mutable sequence (e.g. a minidom ``NodeList``) of
            ``<object>`` elements, each with a ``<name>`` child holding the
            class label.
        keep_classes: Optional collection of class labels to keep. When
            omitted, the module-level ``base`` classes plus ``'0'`` are kept
            (the original comment describes ``'0'`` as the "other" class).

    Returns:
        ``(head, objectlist)`` -- the very same objects that were passed in,
        with ``objectlist`` now holding only the kept elements.
    """
    if keep_classes is None:
        keep_classes = set(base) | {'0'}
    # The original tested `objname is not '0'`, an identity check that only
    # works by accident of string caching; membership uses equality.
    # Slice assignment filters in place so callers holding a reference to
    # the list see the result too.
    objectlist[:] = [
        node for node in objectlist
        if node.getElementsByTagName("name")[0].childNodes[0].data in keep_classes
    ]
    return head, objectlist
def erase_img_modify(origin_img, objectlist, keep_classes=None):
    """Paint the bounding boxes of unwanted objects white, in place.

    Args:
        origin_img: Image array indexed as ``[row, col]`` that supports
            slice assignment (e.g. the array returned by ``cv2.imread``).
        objectlist: Iterable of ``<object>`` elements, each with a ``<name>``
            and a ``<bndbox>`` holding ``xmin``/``ymin``/``xmax``/``ymax``.
        keep_classes: Optional collection of class labels to leave
            untouched. Defaults to the module-level ``base`` classes.

    Returns:
        ``origin_img`` itself, after modification.

    NOTE(review): unlike erase_xml_modify, class '0' gets no special
    treatment here and is whitened by default -- confirm that is intended.
    """
    if keep_classes is None:
        keep_classes = base
    color = [255, 255, 255]
    for obj in objectlist:
        objname = obj.getElementsByTagName("name")[0].childNodes[0].data
        if objname in keep_classes:
            continue
        bndbox = obj.getElementsByTagName('bndbox')[0]
        # int(float(...)) also accepts coordinates written as "12.0".
        # Clamping at 0 stops negative values from being read as
        # "count from the end" indices.
        x_min, y_min, x_max, y_max = (
            max(int(float(bndbox.getElementsByTagName(tag)[0].childNodes[0].data)), 0)
            for tag in ('xmin', 'ymin', 'xmax', 'ymax')
        )
        # One slice assignment replaces the per-pixel double loop; as before
        # the max row/column is exclusive. Slices clip at the image border
        # instead of raising IndexError.
        origin_img[y_min:y_max, x_min:x_max] = color
    return origin_img
def erase_img_xml_dataset(origin_img_path, origin_xml_path, e_img_path, e_xml_path):
    """Write the erased copy of one image/annotation pair.

    Args:
        origin_img_path: Path of the source image.
        origin_xml_path: Path of the source annotation file.
        e_img_path: Output image path without extension ('.jpg' is appended).
        e_xml_path: Output annotation path without extension ('.xml' is
            appended).

    Raises:
        FileNotFoundError: The source image could not be read.
        OSError: The modified image could not be written.
    """
    origin_image = cv2.imread(origin_img_path)
    if origin_image is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise FileNotFoundError('cannot read image: {}'.format(origin_img_path))

    head, objectlist = xml_parse.voc_xml_parse(origin_xml_path)

    # The image must be processed first: erase_xml_modify prunes objectlist
    # in place, after which the unwanted boxes are no longer available.
    m_img = erase_img_modify(origin_image, objectlist)
    modify_img_path = e_img_path + '.jpg'
    if not cv2.imwrite(modify_img_path, m_img):
        raise OSError('cannot write image: {}'.format(modify_img_path))

    modify_head, modify_objectlist = erase_xml_modify(head, objectlist)
    modify_xml_path = e_xml_path + '.xml'
    xml_parse.voc_xml_modify(modify_xml_path, modify_head, modify_objectlist)
if __name__ == '__main__':
    # The list file holds one image id per line. It is only read, so open
    # it read-only instead of 'r+', and ignore blank lines, which would
    # otherwise turn into bogus '.jpg' / '.xml' paths.
    with open(os.path.join(main_path, base_name), 'r') as f:
        input_list = [line.strip() for line in f if line.strip()]

    for name in tqdm(input_list):
        origin_img_path = os.path.join(img_path, name + '.jpg')
        origin_xml_path = os.path.join(xml_path, name + '.xml')
        # Output paths are passed without extension; the callee appends it.
        e_img_path = os.path.join(erase_img_trainval, name)
        e_xml_path = os.path.join(erase_xml_trainval, name)
        erase_img_xml_dataset(origin_img_path, origin_xml_path, e_img_path, e_xml_path)
另附 xml_parse.py,用于辅助解析和修改 xml 文件:
import xml.dom.minidom as xmldom
import os.path
def voc_xml_parse(xml_path):
    """Parse an annotation file into header elements and object elements.

    Args:
        xml_path: Path of the annotation xml file.

    Returns:
        ``(head, object_list)`` where ``head`` maps each of 'folder',
        'filename', 'source', 'size' and 'segmented' to the first element
        with that tag name, and ``object_list`` is the NodeList of all
        ``<object>`` elements. A header tag that does not occur in the file
        is left out of ``head`` instead of raising IndexError. The 'owner'
        element is not collected.
    """
    root = xmldom.parse(xml_path).documentElement
    head = {}
    for tag in ('folder', 'filename', 'source', 'size', 'segmented'):
        matches = root.getElementsByTagName(tag)
        if matches:
            head[tag] = matches[0]
    return head, root.getElementsByTagName("object")
def _strip_layout_whitespace(node):
    """Recursively drop whitespace-only text nodes that sit between elements."""
    children = list(node.childNodes)
    has_elements = any(c.nodeType == c.ELEMENT_NODE for c in children)
    for child in children:
        if child.nodeType == child.ELEMENT_NODE:
            _strip_layout_whitespace(child)
        elif has_elements and child.nodeType == child.TEXT_NODE and not child.data.strip():
            node.removeChild(child)


def voc_xml_modify(modify_xml_path, head, object_list):
    """Write header and object elements to a new ``<annotation>`` xml file.

    Args:
        modify_xml_path: Path of the file to write (UTF-8).
        head: Mapping whose values are the header elements; they are written
            in the mapping's iteration order.
        object_list: Iterable of ``<object>`` elements written after the
            header.

    The given nodes are moved into the new document. Whitespace-only text
    nodes between elements are removed before writing; they are left over
    from the source file's own indentation and would otherwise be re-emitted
    as whitespace-only lines on top of the indentation added by ``writexml``.
    """
    dom = xmldom.Document()
    root = dom.createElement('annotation')
    dom.appendChild(root)
    for node in list(head.values()) + list(object_list):
        root.appendChild(node)
    _strip_layout_whitespace(root)
    with open(modify_xml_path, 'w', encoding='utf-8') as f:
        dom.writexml(f, addindent='\t', newl='\n', encoding='utf-8')