python--xml文件批量筛选出目标(多版本效率对比)

前言

  • 任务目标:写个脚本从数据集的xml中筛选出需要训练的目标数据,对xml进行修改并另外保存。。由于数据量可能较大,写多个版本测试性能。自用
  • 以下测试一共4285张图片,会被筛出528张

无并发版本

耗时:14.60274467599811秒

# -*- coding: utf-8 -*-

# @Description: 从xml中选出要保留的目标 运行时间14.60274467599811
# @Author: TuanZhangSama
# @Date: 2019-07-02 10:47:10
# @LastEditTime: 2019-07-02 13:23:28
# @LastEditors: TuanZhangSama

import xml.etree.ElementTree as ET
import os
import shutil
from tqdm import tqdm
import time
from contextlib import contextmanager
from pdb import set_trace

def get_all_xml_path(xml_dir:str,filter_=['.xml']):
    #遍历文件夹下所有的xml
    result=[]
    for maindir,subdir,file_name_list in os.walk(xml_dir):
        for filename in file_name_list:
            ext=os.path.splitext(filename)[1]#返回扩展名
            if ext in filter_:
                result.append(os.path.join(maindir,filename))
    return result

def deal_xml(xml_path,savedir,classes):
    if not os.path.exists(savedir):
        os.mkdir(savedir)
    xml_name=os.path.basename(xml_path)
    jpg_name=xml_name.replace('.xml','.jpg')
    tree=ET.parse(xml_path)
    root=tree.getroot()
    for obj in root.findall('object'):
        obj_name=obj.find('name').text
        if obj_name not in classes:
            root.remove(obj)
    if root.find('object') is not None:
        tree.write(os.path.join(savedir,xml_name))
        shutil.copy(xml_path.replace('.xml','.jpg'),savedir)

@contextmanager
def timeblock(label:str):
    r'''上下文管理测试代码块运行时间,需要
        import time
        from contextlib import contextmanager

    Examples
    ----------
        with timeblock('counting'):
            ....
    '''
    start = time.perf_counter()
    try:
        yield
    finally:
        end = time.perf_counter()
        print('{} : {}'.format(label, end - start))

if __name__=='__main__':
    with timeblock('time'):
        xml_dir='/home/chiebotgpuhq/Share/gpu-server/data/game/new_train/cascade_dataset/train'
        savedir='/home/chiebotgpuhq/MyCode/python/pytorch/mmdetection-master/testdata/save'
        classes=('jyz_pl')
        for i in tqdm(get_all_xml_path(xml_dir),total=len(get_all_xml_path(xml_dir))):
            deal_xml(i,savedir,classes)

多进程版本

方法1

用 multiprocessing,开满8核,耗时:3.3734715720056556秒

# -*- coding: utf-8 -*-

# @Description: 处理xml筛选目标 多进程版 运行时间:3.3734715720056556
# @Author: TuanZhangSama
# @Date: 2019-07-02 13:26:36
# @LastEditTime: 2019-07-02 13:26:41
# @LastEditors: TuanZhangSama

import xml.etree.ElementTree as ET
import os
import shutil
from tqdm import tqdm
import time
from contextlib import contextmanager
from pdb import set_trace
from multiprocessing import Pool,freeze_support,cpu_count

def get_all_xml_path(xml_dir:str,filter_=['.xml']):
    #遍历文件夹下所有的xml
    result=[]
    for maindir,subdir,file_name_list in os.walk(xml_dir):
        for filename in file_name_list:
            ext=os.path.splitext(filename)[1]#返回扩展名
            if ext in filter_:
                result.append(os.path.join(maindir,filename))
    return result

def deal_xml(xml_path,savedir,classes):
    if not os.path.exists(savedir):
        os.mkdir(savedir)
    xml_name=os.path.basename(xml_path)
    jpg_name=xml_name.replace('.xml','.jpg')
    tree=ET.parse(xml_path)
    root=tree.getroot()
    for obj in root.findall('object'):
        obj_name=obj.find('name').text
        if obj_name not in classes:
            root.remove(obj)
    if root.find('object') is not None:
        tree.write(os.path.join(savedir,xml_name))
        shutil.copy(xml_path.replace('.xml','.jpg'),savedir)

def deal_xml_batch(xmls_path,savedir,classes):
    for i in xmls_path:
        deal_xml(i,savedir,classes)

@contextmanager
def timeblock(label:str):
    r'''上下文管理测试代码块运行时间,需要
        import time
        from contextlib import contextmanager

    Examples
    ----------
        with timeblock('counting'):
            ....
    '''
    start = time.perf_counter()
    try:
        yield
    finally:
        end = time.perf_counter()
        print('{} : {}'.format(label, end - start))

if __name__=='__main__':
    with timeblock('time'):
        xml_dir='/home/chiebotgpuhq/Share/gpu-server/data/game/new_train/cascade_dataset/train'
        savedir='/home/chiebotgpuhq/MyCode/python/pytorch/mmdetection-master/testdata/save'
        classes=('jyz_pl')
        freeze_support()
        xmls_path=get_all_xml_path(xml_dir)
        worker_num=cpu_count()
        print('your CPU num is:',worker_num)
        length=float(len(xmls_path))/float(worker_num)
        #计算下标,尽可能均匀地划分输入文件的列表
        indices=[int(round(i*length)) for i in range(worker_num+1)]

        #生成每个进程要处理的子文件列表
        sublists=[xmls_path[indices[i]:indices[i+1]] for i in range(worker_num)]
        pool=Pool(processes=worker_num)
        for i in range(worker_num):
            pool.apply_async(deal_xml_batch,args=(sublists[i],savedir,classes))
        pool.close()
        pool.join()

方法2

用ProcessPoolExecutor 开满8核,耗时3.1331693999964045秒
这里的一些坑:

  • map仅支持一个参数,即使用zip打包了参数也没法用
# -*- coding: utf-8 -*-

# @Description: 从xml中选出要保留的目标 运行时间
# @Author: TuanzhangSama
# @Date: 2019-07-02 10:47:10
# @LastEditTime: 2019-07-02 15:14:04
# @LastEditors: TuanzhangSama

import xml.etree.ElementTree as ET
import os
import shutil
from tqdm import tqdm
import time
from contextlib import contextmanager
from pdb import set_trace
from concurrent import futures

def get_all_xml_path(xml_dir:str,filter_=['.xml']):
    #遍历文件夹下所有的xml
    result=[]
    for maindir,subdir,file_name_list in os.walk(xml_dir):
        for filename in file_name_list:
            ext=os.path.splitext(filename)[1]#返回扩展名
            if ext in filter_:
                result.append(os.path.join(maindir,filename))
    return result

def deal_xml(xml_path,savedir,classes):
    if not os.path.exists(savedir):
        os.mkdir(savedir)
    xml_name=os.path.basename(xml_path)
    jpg_name=xml_name.replace('.xml','.jpg')
    tree=ET.parse(xml_path)
    root=tree.getroot()
    for obj in root.findall('object'):
        obj_name=obj.find('name').text
        if obj_name not in classes:
            root.remove(obj)
    if root.find('object') is not None:
        tree.write(os.path.join(savedir,xml_name))
        shutil.copy(xml_path.replace('.xml','.jpg'),savedir)

@contextmanager
def timeblock(label:str):
    r'''上下文管理测试代码块运行时间,需要
        import time
        from contextlib import contextmanager

    Examples
    ----------
        with timeblock('counting'):
            ....
    '''
    start = time.perf_counter()
    try:
        yield
    finally:
        end = time.perf_counter()
        print('{} : {}'.format(label, end - start))

if __name__=='__main__':
    with timeblock('time'):
        xml_dir='/home/chiebotgpuhq/Share/gpu-server/data/game/new_train/cascade_dataset/train'
        savedir='/home/chiebotgpuhq/MyCode/python/pytorch/mmdetection-master/testdata/save'
        classes=('jyz_pl',)
        xmls_path=get_all_xml_path(xml_dir)
        with futures.ProcessPoolExecutor() as pool:
            for x in (pool.submit(deal_xml,xml_path,savedir,classes) for xml_path in xmls_path):
                pass

多线程版本

方法1

使用ThreadPoolExecutor,耗时: 1.895962769005564秒
使用如果 max_workers 为 None 或没有指定,将默认为机器处理器的个数,假如 ThreadPoolExecutor 则重于 I/O 操作而不是 CPU 运算,那么可以乘以 5,同时工作线程的数量可以比 ProcessPoolExecutor 的数量高。

# -*- coding: utf-8 -*-

# @Description: 从xml中选出要保留的目标 运行时间
# @Author: HuQiong
# @Date: 2019-07-02 10:47:10
# @LastEditTime: 2019-07-03 14:13:14
# @LastEditors: HuQiong

import xml.etree.ElementTree as ET
import os
import shutil
from tqdm import tqdm
import time
from contextlib import contextmanager
from pdb import set_trace
from concurrent import futures

def get_all_xml_path(xml_dir:str,filter_=['.xml']):
    #遍历文件夹下所有的xml
    result=[]
    for maindir,subdir,file_name_list in os.walk(xml_dir):
        for filename in file_name_list:
            ext=os.path.splitext(filename)[1]#返回扩展名
            if ext in filter_:
                result.append(os.path.join(maindir,filename))
    return result

def deal_xml(xml_path,savedir,classes=None,changelabel_dict=None):
    if not os.path.exists(savedir):
        os.mkdir(savedir)
    xml_name=os.path.basename(xml_path)
    jpg_name=xml_name.replace('.xml','.jpg')
    tree=ET.parse(xml_path)
    root=tree.getroot()
    for obj in root.findall('object'):
        obj_name=obj.find('name').text
        if (changelabel_dict is not None) and (obj_name in changelabel_dict.keys()):
            obj.find('name').text=changelabel_dict[obj_name]
            continue
        if (classes is not None) and (obj_name not in classes):
            root.remove(obj)
            continue

    if root.find('object') is not None:
        tree.write(os.path.join(savedir,xml_name))
        shutil.copy(xml_path.replace('.xml','.jpg'),savedir)


@contextmanager
def timeblock(label:str):
    start = time.perf_counter()
    try:
        yield
    finally:
        end = time.perf_counter()
        print('{} : {}'.format(label, end - start))

def main(xml_dir,savedir,classes=None,changelabel_dict=None,work_num=64):
    r'''筛选并修改xml标签,注意筛选或者修改之后的标签都会被保存出来
    
    Parameters
    ----------
    xmldir : (str)
        xml文件所在目录,会搜索子目录
    
    savedir : (str)
        筛选和修改之后的数据存放的目录,会将修改之后的标签和图片拷贝过来
    
    classes : (tuple = None)
        指定了 哪些标签会被保留,若为None,则所有标签都会保留
    
    changelabel_dict : (dict = None)
        键表示需要修改的标签名,值修改之后的标签名,默认为None就是所有标签不修改

    work_num: (int = 64)
        最大线程数量
    '''
    xmls_path=get_all_xml_path(xml_dir)
    with futures.ThreadPoolExecutor(work_num) as pool:
        task_list=(pool.submit(deal_xml,xml_path,savedir,classes,changelabel_dict) for xml_path in xmls_path)
        for x in tqdm(futures.as_completed(task_list),total=len(xmls_path)):
            pass




if __name__=='__main__':
    with timeblock('time'):
        xml_dir='/home/gpu-server/project/data/game/bj_data'
        savedir='/home/gpu-server/project/detection_test/dataset/0703_bujian'
        classes=('bj','jsxs','sly_dmyw','yw_nc','yw_gkxfw','byq_hxq','sly_bjbmyw','jyz')
        changelabel_dict={
                'dlq_dzx':'xtsb',
                'xhzz_znx':'xtsb',
                'byq_dzx':'xtsb',
                'dz_czjgx':'xtsb',
                'dlq_czx':'xtsb',
                'zhdq_jgx':'xtsb',
                'byq_tg':'jyz',
                'cqtg_bt':'jyz',
                'zhdq_tg':'jyz'
        }
        main(xml_dir,savedir,classes,changelabel_dict=changelabel_dict)

协程版本

使用说明 : 1、本程序可以作什么? 本程序功能分为两个部分: 一部分是:批量查找指定路径下的文件中的内容中,是否包含要查询的项目。并把查询的内容分文件存储。 一部分是:将文本文件导入EXCEL中,可以将上一步查找的结果导入,也可以自己选择文件导入(支持多选)。 2、如何使用他? a、批量查找: 首先,在“读入位置”按钮处设置你所要读取的文件的存放位置,此时程序会自动读入此文件夹下所有文件,以供选择;其次,用“>>”或“>”将要读取的文件选入读取队列,当然如果选错了可以用“<<”或“<”删除队列,或者鼠标双击选中项删除;再次,点击“+”按钮,添加查找项目到查找项目列表,一次只可以添加一条,如需添加多条则需要重复添加操作;如果添加错误可以双击选中项删除或选中后点击“-”按钮。最后,点击“开始查找”,程序将会把查找结果输保存到指定路径下面的output文件夹下面,你可以选择是否打开目录查看。如果需要查询的文件有文件头,可以选择“保留文件首行”。 b、EXCEL导入: 首先,选择导入方式,导入方式分为“查询结果导入”和“新选文件导入”两种;当选择“查询结果导入”时,本程序将把“读入位置”处“output”文件夹下文件批量导入EXCEL。当选择“新选文件导入”时,本程序在点击“开始导入”时将弹窗口,您可以自己选择需要导入文件(支持多选),导入EXCEL。其次,设置导入文件时的分割符,默认为“|”,本程序只支持按照分隔符导入。最后,点击“开始导入”按钮开始导入。 3、本程序不判断所查找的文件类型 由于本程序在读入文件时,并没有校验文件的内容和文件类型,因此本程序会读取用户所选择的任意文件,即使此文件是二进制格式的。不论是查询或者是导入功能都是这样。本程序将按行读取所选择的文件(或者有换行符的),在读取完文件后,无论是否找到,都会创建和源文件相同类型的文件,即使是.exe或.rar(一般是打不开的),文件名存储为“output”+原文件名。即使没有找到任何相匹配的内容,本程序也会创建文件,这时后文件大小是0字节,可以按照大小排列看到。 4、请使用“清除文件”按钮及时清除查询结果 程序在查询和创建文件的过程中,不会判断是否已经执行过查询操作。如果已经执行过查询操作,“output”文件夹下就会存在查询的文件,当再次执行查询时,本程序会在已存在的文件后追加查询结果。这样就会现重复的记录或内容。因此,当需要多次查询时,每次查询前需要点击“清空文件”按钮删除output文件夹,才能保证查找的准确。 5、本程序不会判断运行的环境,因此在运行过程中可能会有些未知的错误 本程序在win7环境,vs2012,Netframe4.0下编译通过。本程序支持winxp及以上操作系统。执行EXCEL导入的时候,需要安装Office。Office的版本在2003以上就可以。但是不同我Office版本对导入性能,有一定的影响。Excel2003,最多256列,即2的8次方,最多65536行,即2放入16次方; Excel2007及以上版本,最多16384列,即2的14次方,最多1048576行,即2的20次方。因此如果需要导入的单个文件的行数或者列数,超过了所安装Office版本的最多行列数,程序将会报错!
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值