前言
- 任务目标:写个脚本,从数据集的xml中筛选出需要训练的目标数据,对xml进行修改并另外保存。由于数据量可能较大,写了多个版本来测试性能。自用
- 以下测试一共4285张图片,会被筛出528张
无并发版本
耗时:14.60274467599811秒
# -*- coding: utf-8 -*-
# @Description: 从xml中选出要保留的目标 运行时间14.60274467599811
# @Author: TuanZhangSama
# @Date: 2019-07-02 10:47:10
# @LastEditTime: 2019-07-02 13:23:28
# @LastEditors: TuanZhangSama
import xml.etree.ElementTree as ET
import os
import shutil
from tqdm import tqdm
import time
from contextlib import contextmanager
from pdb import set_trace
def get_all_xml_path(xml_dir:str,filter_=['.xml']):
    r'''Recursively collect the paths of files whose extension is in ``filter_``.

    Parameters
    ----------
    xml_dir : (str)
        Root directory to walk; sub-directories are searched as well.
    filter_ : (list = ['.xml'])
        Extensions (including the leading dot) to keep. The comparison is
        exact, so it is case-sensitive. The default list is only read here,
        never modified.

    Returns
    -------
    list
        One joined ``directory/filename`` path per matching file, in the
        order ``os.walk`` visits them.
    '''
    return [
        os.path.join(folder,name)
        for folder,_subdirs,names in os.walk(xml_dir)
        for name in names
        if os.path.splitext(name)[1] in filter_
    ]
def deal_xml(xml_path,savedir,classes):
    r'''Keep only the wanted ``<object>`` entries of one annotation file.

    Every ``<object>`` directly under the XML root whose ``<name>`` text is
    not in ``classes`` is removed. If at least one object is left, the
    modified XML is written to ``savedir`` under its original file name and
    the ``.jpg`` file sharing the XML's path stem is copied next to it.
    If no object is left, nothing is written or copied.

    Parameters
    ----------
    xml_path : (str)
        Path of the annotation XML file to process.
    savedir : (str)
        Output directory; created (including parents) when missing.
    classes : (tuple or str)
        Label names to keep. A bare string is treated as one label name.

    Raises
    ------
    FileNotFoundError
        If the XML file, or the matching ``.jpg`` of a kept XML, is missing.
    '''
    if isinstance(classes,str):
        # ('jyz_pl') without a trailing comma is a plain string, which would
        # turn the `in` test below into substring matching.
        classes=(classes,)
    # exist_ok avoids the check-then-create race when several workers start
    # together; makedirs also creates missing parent directories.
    os.makedirs(savedir,exist_ok=True)
    xml_name=os.path.basename(xml_path)
    tree=ET.parse(xml_path)
    root=tree.getroot()
    # findall returns a list, so removing children while looping is safe.
    for obj in root.findall('object'):
        # findtext yields None when <name> is absent; such objects are
        # treated as not matching and dropped instead of crashing.
        if obj.findtext('name') not in classes:
            root.remove(obj)
    if root.find('object') is not None:
        tree.write(os.path.join(savedir,xml_name))
        # Swap only the final extension; str.replace would also rewrite any
        # '.xml' that appears inside a directory name.
        shutil.copy(os.path.splitext(xml_path)[0]+'.jpg',savedir)
@contextmanager
def timeblock(label:str):
    r'''Context manager that prints the wall-clock duration of its body.

    On exit, including when the body raises, one line of the form
    ``<label> : <seconds>`` is printed, measured with ``time.perf_counter``.
    Requires
    import time
    from contextlib import contextmanager

    Examples
    ----------
    with timeblock('counting'):
        ....
    '''
    began=time.perf_counter()
    try:
        yield
    finally:
        # Measured in `finally` so a failing body is still reported.
        elapsed=time.perf_counter()-began
        print('{} : {}'.format(label, elapsed))
if __name__=='__main__':
    # Serial baseline: process every annotation file one after another.
    with timeblock('time'):
        xml_dir='/home/chiebotgpuhq/Share/gpu-server/data/game/new_train/cascade_dataset/train'
        savedir='/home/chiebotgpuhq/MyCode/python/pytorch/mmdetection-master/testdata/save'
        # The trailing comma makes this a 1-tuple. ('jyz_pl') is just the
        # string 'jyz_pl', and `name in 'jyz_pl'` is a substring test that
        # would also keep labels such as 'jyz'.
        classes=('jyz_pl',)
        # Walk the directory tree once and reuse the list for the total.
        xmls_path=get_all_xml_path(xml_dir)
        for xml_path in tqdm(xmls_path,total=len(xmls_path)):
            deal_xml(xml_path,savedir,classes)
多进程版本
方法1
用 multiprocessing,开满8核,耗时:3.3734715720056556秒
# -*- coding: utf-8 -*-
# @Description: 处理xml筛选目标 多进程版 运行时间:3.3734715720056556
# @Author: TuanZhangSama
# @Date: 2019-07-02 13:26:36
# @LastEditTime: 2019-07-02 13:26:41
# @LastEditors: TuanZhangSama
import xml.etree.ElementTree as ET
import os
import shutil
from tqdm import tqdm
import time
from contextlib import contextmanager
from pdb import set_trace
from multiprocessing import Pool,freeze_support,cpu_count
def get_all_xml_path(xml_dir:str,filter_=['.xml']):
    r'''Walk ``xml_dir`` and return every file path with a wanted extension.

    Parameters
    ----------
    xml_dir : (str)
        Directory to search, including all of its sub-directories.
    filter_ : (list = ['.xml'])
        Accepted extensions with their leading dot; compared exactly
        (case-sensitive). The default list is only read, never modified.

    Returns
    -------
    list
        Joined ``directory/filename`` paths in ``os.walk`` order.
    '''
    return [
        os.path.join(folder,name)
        for folder,_subdirs,names in os.walk(xml_dir)
        for name in names
        if os.path.splitext(name)[1] in filter_
    ]
def deal_xml(xml_path,savedir,classes):
    r'''Keep only the wanted ``<object>`` entries of one annotation file.

    Every ``<object>`` directly under the XML root whose ``<name>`` text is
    not in ``classes`` is removed. If at least one object is left, the
    modified XML is written to ``savedir`` under its original file name and
    the ``.jpg`` file sharing the XML's path stem is copied next to it.
    If no object is left, nothing is written or copied.

    Parameters
    ----------
    xml_path : (str)
        Path of the annotation XML file to process.
    savedir : (str)
        Output directory; created (including parents) when missing.
    classes : (tuple or str)
        Label names to keep. A bare string is treated as one label name.

    Raises
    ------
    FileNotFoundError
        If the XML file, or the matching ``.jpg`` of a kept XML, is missing.
    '''
    if isinstance(classes,str):
        # ('jyz_pl') without a trailing comma is a plain string, which would
        # turn the `in` test below into substring matching.
        classes=(classes,)
    # Several worker processes call this at the same time: exist_ok avoids
    # the check-then-create race that makes os.mkdir raise FileExistsError.
    os.makedirs(savedir,exist_ok=True)
    xml_name=os.path.basename(xml_path)
    tree=ET.parse(xml_path)
    root=tree.getroot()
    # findall returns a list, so removing children while looping is safe.
    for obj in root.findall('object'):
        # findtext yields None when <name> is absent; such objects are
        # treated as not matching and dropped instead of crashing.
        if obj.findtext('name') not in classes:
            root.remove(obj)
    if root.find('object') is not None:
        tree.write(os.path.join(savedir,xml_name))
        # Swap only the final extension; str.replace would also rewrite any
        # '.xml' that appears inside a directory name.
        shutil.copy(os.path.splitext(xml_path)[0]+'.jpg',savedir)
def deal_xml_batch(xmls_path,savedir,classes):
    r'''Run ``deal_xml`` on each path of one batch, in order.

    Parameters
    ----------
    xmls_path : (list)
        Annotation XML paths that make up this batch.
    savedir : (str)
        Output directory, forwarded unchanged to ``deal_xml``.
    classes : (tuple)
        Label names to keep, forwarded unchanged to ``deal_xml``.
    '''
    for xml_path in xmls_path:
        deal_xml(xml_path,savedir,classes)
@contextmanager
def timeblock(label:str):
    r'''Context manager that prints the wall-clock duration of its body.

    On exit, including when the body raises, one line of the form
    ``<label> : <seconds>`` is printed, measured with ``time.perf_counter``.
    Requires
    import time
    from contextlib import contextmanager

    Examples
    ----------
    with timeblock('counting'):
        ....
    '''
    began=time.perf_counter()
    try:
        yield
    finally:
        # Measured in `finally` so a failing body is still reported.
        elapsed=time.perf_counter()-began
        print('{} : {}'.format(label, elapsed))
if __name__=='__main__':
    # The multiprocessing docs require freeze_support() to be called straight
    # after the __main__ guard (it only matters for frozen Windows builds).
    freeze_support()
    with timeblock('time'):
        xml_dir='/home/chiebotgpuhq/Share/gpu-server/data/game/new_train/cascade_dataset/train'
        savedir='/home/chiebotgpuhq/MyCode/python/pytorch/mmdetection-master/testdata/save'
        # The trailing comma makes this a 1-tuple. ('jyz_pl') is just the
        # string 'jyz_pl', and `name in 'jyz_pl'` is a substring test.
        classes=('jyz_pl',)
        xmls_path=get_all_xml_path(xml_dir)
        worker_num=cpu_count()
        print('your CPU num is:',worker_num)
        length=len(xmls_path)/worker_num
        # Compute slice boundaries that split the file list as evenly as possible.
        indices=[int(round(i*length)) for i in range(worker_num+1)]
        # Build the sub-list of files each process will handle.
        sublists=[xmls_path[indices[i]:indices[i+1]] for i in range(worker_num)]
        pool=Pool(processes=worker_num)
        # Keep the AsyncResult handles: an exception raised in a worker is
        # only surfaced when get() is called on its result.
        pending=[pool.apply_async(deal_xml_batch,args=(chunk,savedir,classes)) for chunk in sublists]
        pool.close()
        pool.join()
        for task in pending:
            try:
                task.get()
            except Exception as err:
                # Report rather than abort: the other batches already finished.
                print('worker failed:',repr(err))
方法2
用ProcessPoolExecutor 开满8核,耗时3.1331693999964045秒
这里的一些坑:
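- map 会把每个可迭代对象的元素作为一个位置参数传给函数;如果用 zip 把多个参数打包成元组,函数只会收到一个元组参数,并不会自动解包。需要传多个参数时,可以给 map 传入多个可迭代对象(例如配合 itertools.repeat),或者像下面这样改用 submit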
# -*- coding: utf-8 -*-
# @Description: 从xml中选出要保留的目标 运行时间
# @Author: TuanzhangSama
# @Date: 2019-07-02 10:47:10
# @LastEditTime: 2019-07-02 15:14:04
# @LastEditors: TuanzhangSama
import xml.etree.ElementTree as ET
import os
import shutil
from tqdm import tqdm
import time
from contextlib import contextmanager
from pdb import set_trace
from concurrent import futures
def get_all_xml_path(xml_dir:str,filter_=['.xml']):
    r'''Return the paths of all files below ``xml_dir`` with a wanted extension.

    Parameters
    ----------
    xml_dir : (str)
        Directory to search, including all of its sub-directories.
    filter_ : (list = ['.xml'])
        Accepted extensions with their leading dot; compared exactly
        (case-sensitive). The default list is only read, never modified.

    Returns
    -------
    list
        Joined ``directory/filename`` paths in ``os.walk`` order.
    '''
    return [
        os.path.join(folder,name)
        for folder,_subdirs,names in os.walk(xml_dir)
        for name in names
        if os.path.splitext(name)[1] in filter_
    ]
def deal_xml(xml_path,savedir,classes):
    r'''Keep only the wanted ``<object>`` entries of one annotation file.

    Every ``<object>`` directly under the XML root whose ``<name>`` text is
    not in ``classes`` is removed. If at least one object is left, the
    modified XML is written to ``savedir`` under its original file name and
    the ``.jpg`` file sharing the XML's path stem is copied next to it.
    If no object is left, nothing is written or copied.

    Parameters
    ----------
    xml_path : (str)
        Path of the annotation XML file to process.
    savedir : (str)
        Output directory; created (including parents) when missing.
    classes : (tuple or str)
        Label names to keep. A bare string is treated as one label name.

    Raises
    ------
    FileNotFoundError
        If the XML file, or the matching ``.jpg`` of a kept XML, is missing.
    '''
    if isinstance(classes,str):
        # A tuple written without its trailing comma is a plain string, which
        # would turn the `in` test below into substring matching.
        classes=(classes,)
    # Several worker processes call this at the same time: exist_ok avoids
    # the check-then-create race that makes os.mkdir raise FileExistsError.
    os.makedirs(savedir,exist_ok=True)
    xml_name=os.path.basename(xml_path)
    tree=ET.parse(xml_path)
    root=tree.getroot()
    # findall returns a list, so removing children while looping is safe.
    for obj in root.findall('object'):
        # findtext yields None when <name> is absent; such objects are
        # treated as not matching and dropped instead of crashing.
        if obj.findtext('name') not in classes:
            root.remove(obj)
    if root.find('object') is not None:
        tree.write(os.path.join(savedir,xml_name))
        # Swap only the final extension; str.replace would also rewrite any
        # '.xml' that appears inside a directory name.
        shutil.copy(os.path.splitext(xml_path)[0]+'.jpg',savedir)
@contextmanager
def timeblock(label:str):
    r'''Context manager that prints the wall-clock duration of its body.

    On exit, including when the body raises, one line of the form
    ``<label> : <seconds>`` is printed, measured with ``time.perf_counter``.
    Requires
    import time
    from contextlib import contextmanager

    Examples
    ----------
    with timeblock('counting'):
        ....
    '''
    began=time.perf_counter()
    try:
        yield
    finally:
        # Measured in `finally` so a failing body is still reported.
        elapsed=time.perf_counter()-began
        print('{} : {}'.format(label, elapsed))
if __name__=='__main__':
    # Process-pool variant: one task per annotation file.
    with timeblock('time'):
        xml_dir='/home/chiebotgpuhq/Share/gpu-server/data/game/new_train/cascade_dataset/train'
        savedir='/home/chiebotgpuhq/MyCode/python/pytorch/mmdetection-master/testdata/save'
        classes=('jyz_pl',)
        xmls_path=get_all_xml_path(xml_dir)
        with futures.ProcessPoolExecutor() as pool:
            # Map each future back to its file so failures can be reported.
            tasks={pool.submit(deal_xml,xml_path,savedir,classes):xml_path for xml_path in xmls_path}
            for done in futures.as_completed(tasks):
                # An exception raised inside a worker is stored on the future
                # and stays invisible unless it is inspected here.
                err=done.exception()
                if err is not None:
                    print('failed:',tasks[done],repr(err))
多线程版本
方法1
使用ThreadPoolExecutor,耗时: 1.895962769005564秒
如果 max_workers 为 None 或没有指定,将默认为机器处理器个数的 5 倍(Python 3.5–3.7 的行为;Python 3.8 起默认值改为 min(32, 处理器个数 + 4))。这是因为 ThreadPoolExecutor 侧重于 I/O 操作而不是 CPU 运算,所以工作线程的数量可以比 ProcessPoolExecutor 的进程数量高。
# -*- coding: utf-8 -*-
# @Description: 从xml中选出要保留的目标 运行时间
# @Author: HuQiong
# @Date: 2019-07-02 10:47:10
# @LastEditTime: 2019-07-03 14:13:14
# @LastEditors: HuQiong
import xml.etree.ElementTree as ET
import os
import shutil
from tqdm import tqdm
import time
from contextlib import contextmanager
from pdb import set_trace
from concurrent import futures
def get_all_xml_path(xml_dir:str,filter_=['.xml']):
    r'''Return the paths of all files below ``xml_dir`` with a wanted extension.

    Parameters
    ----------
    xml_dir : (str)
        Directory to search, including all of its sub-directories.
    filter_ : (list = ['.xml'])
        Accepted extensions with their leading dot; compared exactly
        (case-sensitive). The default list is only read, never modified.

    Returns
    -------
    list
        Joined ``directory/filename`` paths in ``os.walk`` order.
    '''
    return [
        os.path.join(folder,name)
        for folder,_subdirs,names in os.walk(xml_dir)
        for name in names
        if os.path.splitext(name)[1] in filter_
    ]
def deal_xml(xml_path,savedir,classes=None,changelabel_dict=None):
    r'''Rename and/or filter the ``<object>`` entries of one annotation file.

    For every ``<object>`` directly under the XML root:
    - if its ``<name>`` is a key of ``changelabel_dict``, the name is replaced
      by the mapped value and the object is kept (it is not checked against
      ``classes``);
    - otherwise, if ``classes`` is given and the name is not in it, the
      object is removed.
    If at least one object is left, the modified XML is written to
    ``savedir`` under its original file name and the ``.jpg`` file sharing
    the XML's path stem is copied next to it. Otherwise nothing is written.

    Parameters
    ----------
    xml_path : (str)
        Path of the annotation XML file to process.
    savedir : (str)
        Output directory; created (including parents) when missing.
    classes : (tuple or str = None)
        Label names to keep; None keeps every label. A bare string is
        treated as one label name.
    changelabel_dict : (dict = None)
        Maps old label names to new ones; None renames nothing.

    Raises
    ------
    FileNotFoundError
        If the XML file, or the matching ``.jpg`` of a kept XML, is missing.
    '''
    if isinstance(classes,str):
        # A tuple written without its trailing comma is a plain string, which
        # would turn the `in` test below into substring matching.
        classes=(classes,)
    # Many threads call this at the same time: exist_ok avoids the
    # check-then-create race that makes os.mkdir raise FileExistsError.
    os.makedirs(savedir,exist_ok=True)
    xml_name=os.path.basename(xml_path)
    tree=ET.parse(xml_path)
    root=tree.getroot()
    # findall returns a list, so removing children while looping is safe.
    for obj in root.findall('object'):
        name_node=obj.find('name')
        # An object without <name> cannot be renamed; it is handled by the
        # class filter below like any other non-matching label.
        obj_name=None if name_node is None else name_node.text
        if changelabel_dict is not None and obj_name in changelabel_dict:
            name_node.text=changelabel_dict[obj_name]
            continue
        if classes is not None and obj_name not in classes:
            root.remove(obj)
    if root.find('object') is not None:
        tree.write(os.path.join(savedir,xml_name))
        # Swap only the final extension; str.replace would also rewrite any
        # '.xml' that appears inside a directory name.
        shutil.copy(os.path.splitext(xml_path)[0]+'.jpg',savedir)
@contextmanager
def timeblock(label:str):
    r'''Context manager that prints the wall-clock duration of its body.

    On exit, including when the body raises, one line of the form
    ``<label> : <seconds>`` is printed, measured with ``time.perf_counter``.

    Examples
    ----------
    with timeblock('counting'):
        ....
    '''
    began=time.perf_counter()
    try:
        yield
    finally:
        # Measured in `finally` so a failing body is still reported.
        elapsed=time.perf_counter()-began
        print('{} : {}'.format(label, elapsed))
def main(xml_dir,savedir,classes=None,changelabel_dict=None,work_num=64):
    r'''Filter and rename XML labels with a thread pool.

    Every XML file found under ``xml_dir`` is passed to ``deal_xml`` in a
    worker thread. Note that both filtered and renamed labels are saved.
    Files whose processing raised are listed on stdout after the progress
    bar finishes; they do not abort the remaining files.

    Parameters
    ----------
    xml_dir : (str)
        Directory containing the XML files; sub-directories are searched.
    savedir : (str)
        Directory that receives the filtered/renamed XML files together
        with copies of their images.
    classes : (tuple = None)
        Labels to keep; None keeps every label.
    changelabel_dict : (dict = None)
        Keys are the label names to rename, values are the new names;
        None leaves all labels unchanged.
    work_num: (int = 64)
        Maximum number of worker threads.
    '''
    xmls_path=get_all_xml_path(xml_dir)
    failures=[]
    with futures.ThreadPoolExecutor(work_num) as pool:
        # Map each future back to its file so failures can be reported.
        tasks={
            pool.submit(deal_xml,xml_path,savedir,classes,changelabel_dict):xml_path
            for xml_path in xmls_path
        }
        for done in tqdm(futures.as_completed(tasks),total=len(tasks)):
            # An exception raised in a worker is stored on the future and
            # stays invisible unless it is inspected here.
            err=done.exception()
            if err is not None:
                failures.append((tasks[done],err))
    # Printed after the loop so the messages do not break the progress bar.
    for xml_path,err in failures:
        print('failed:',xml_path,repr(err))
if __name__=='__main__':
    # Example run: keep a fixed set of labels and merge several fine-grained
    # labels into 'xtsb' and 'jyz' while timing the whole job.
    with timeblock('time'):
        xml_dir='/home/gpu-server/project/data/game/bj_data'
        savedir='/home/gpu-server/project/detection_test/dataset/0703_bujian'
        classes=('bj','jsxs','sly_dmyw','yw_nc','yw_gkxfw','byq_hxq','sly_bjbmyw','jyz')
        # Old labels grouped by the label they are renamed to.
        changelabel_dict=dict.fromkeys(
            ('dlq_dzx','xhzz_znx','byq_dzx','dz_czjgx','dlq_czx','zhdq_jgx'),
            'xtsb',
        )
        changelabel_dict.update(dict.fromkeys(('byq_tg','cqtg_bt','zhdq_tg'),'jyz'))
        main(xml_dir,savedir,classes=classes,changelabel_dict=changelabel_dict)