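# Personal utility: after VOC label annotation is finished, extract the annotated images and split them into a training set and a validation set.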

"""
--function  1.通过在xml_base中的xml将img_base中每个类图像复制到shutil_copy_to
            2.在txt_base生成train.txt,val.txt文件
--author    图像组——***
--time      2024.5.9
--note      完成Voc标签标注后,将annotation从图像文件夹中拿出来,再根据annotation将每个打过标签的图像拿出来,并按照10%作为验证集划分训练集和验证集
"""
import os
import shutil
from lxml import etree


def parse_xml_to_dict(xml):
    """
    Recursively convert an XML element into nested dictionaries
    (modelled on TensorFlow's recursive_parse_xml_to_dict).

    Args:
        xml: an element of an XML tree (e.g. obtained with lxml.etree)

    Returns:
        A one-key dict ``{tag: value}``. For a leaf element the value is
        its text; otherwise it is a dict of the children's parsed values,
        with every ``object`` child gathered into a list under ``'object'``.
    """
    if len(xml) > 0:
        children = {}
        for node in xml:
            # Recurse first, then unwrap the single-key dict the call returns.
            parsed = parse_xml_to_dict(node)[node.tag]
            if node.tag == 'object':
                # An annotation may hold several objects, so they are
                # accumulated in a list instead of overwriting each other.
                children.setdefault(node.tag, []).append(parsed)
            else:
                children[node.tag] = parsed
        return {xml.tag: children}

    # Leaf element: nothing below it, so return its text directly.
    return {xml.tag: xml.text}


def findAllxml(base):
    """
    Yield one entry per VOC annotation file found under ``base``.

    The directory tree is walked recursively; each ``.xml`` file is parsed
    and the ``filename`` entry of its ``annotation`` root is read.

    Args:
        base: directory that contains the annotation XML files

    Yields:
        ``(img_name, xml_name)`` tuples, where ``img_name`` is the text of
        the ``filename`` element and ``xml_name`` is the annotation file's
        name without its extension.
    """
    for root, _dirs, files in os.walk(base):
        for f in files:
            # Only annotation files are parsed; any other file lying in the
            # tree would otherwise abort the run with a parse error.
            if not f.lower().endswith(".xml"):
                continue
            # Read raw bytes so the parser honours the encoding declared in
            # the XML itself. Text mode decodes with the locale default, and
            # lxml rejects str input that carries an encoding declaration.
            with open(os.path.join(root, f), "rb") as fid:
                xml_bytes = fid.read()
            xml = etree.fromstring(xml_bytes)
            data = parse_xml_to_dict(xml)["annotation"]
            img_name = data["filename"]
            # splitext drops only the final extension, so a stem containing
            # dots ("img.v2.xml" -> "img.v2") is not truncated.
            yield img_name, os.path.splitext(f)[0]


def findAllFile(base):
    """
    Recursively list every file below ``base``.

    Args:
        base: directory to walk

    Yields:
        ``(full_path, file_name)`` tuples, one per file, in os.walk order.
    """
    for dirpath, _subdirs, names in os.walk(base):
        yield from ((os.path.join(dirpath, name), name) for name in names)


def main(xml_base=r'E:\ProjectImage\Dataset\Annotations',
         img_base=r'E:\ProjectImage\Dataset\c_medcien',
         shutil_copy_to=r"E:\ProjectImage\Dataset\cm55\JPEGImages",
         txt_base=r"E:\ProjectImage\Dataset\cm55\ImageSets\Main",
         val_every=10):
    """
    Copy every annotated image into one folder and write the split files.

    For each annotation found under ``xml_base`` the image it names is
    looked up (by file name) among the files under ``img_base`` and copied
    to ``shutil_copy_to``. The annotation stems are then split: every
    ``val_every``-th one goes to ``val.txt``, the rest to ``train.txt``,
    both written into ``txt_base``.

    Args:
        xml_base: directory holding the VOC annotation XML files
        img_base: root directory searched recursively for the images
        shutil_copy_to: destination directory for the copied images
        txt_base: directory that receives ``train.txt`` and ``val.txt``
        val_every: one sample out of this many goes to the validation set
            (10 gives roughly a 90% / 10% split)

    Raises:
        ValueError: if ``val_every`` is smaller than 1.
    """
    if val_every < 1:
        raise ValueError("val_every must be a positive integer")

    # Without an existing destination directory shutil.copy would write
    # every image onto a single *file* with that name.
    os.makedirs(shutil_copy_to, exist_ok=True)
    os.makedirs(txt_base, exist_ok=True)

    # file name -> full path. setdefault keeps the first path seen for a
    # duplicated name, which is what the former list.index lookup returned.
    path_by_name = {}
    for path, filename in findAllFile(img_base):
        path_by_name.setdefault(filename, path)

    xml_list = []
    missing = []
    for img_name, xml_name in findAllxml(xml_base):
        src = path_by_name.get(img_name)
        if src is None:
            # Leave the sample out of the split rather than list an image
            # that was never copied.
            missing.append((xml_name, img_name))
            continue
        shutil.copy(src, shutil_copy_to)
        xml_list.append(xml_name)

    for xml_name, img_name in missing:
        print(f"[skip] {xml_name}: image {img_name!r} not found under {img_base}")

    # Every val_every-th sample (indices 0, val_every, ...) is held out for
    # validation; all remaining samples are used for training.
    eval_list = xml_list[::val_every]
    train_list = [name for idx, name in enumerate(xml_list) if idx % val_every]

    with open(os.path.join(txt_base, "train.txt"), "w") as f_train:
        f_train.writelines(name + '\n' for name in train_list)

    with open(os.path.join(txt_base, "val.txt"), "w") as f_eval:
        f_eval.writelines(name + '\n' for name in eval_list)

# Run the extraction and split only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()

# (Scraped blog-page UI text that followed the script - like/favourite counters and a payment dialog - removed; it was not part of the program.)