# 【数据集处理】生成voc格式数据集 — Dataset processing: generate a VOC-format dataset


import json
import re
import os
import xml.dom.minidom
import cv2
import xml.etree.ElementTree as ET
import random

# JSON label export read by getJsonAndClasses.
txt_path = "jsonFile.txt"
# Directory holding the source images.
image_Path = "./VOC2007/JPEGImages/"
# Output file that getJsonAndClasses fills with one class name per line.
voc_classes_path = "./voc_classes.txt"

# Directory the generated XML annotations are written to; change as needed.
_ANNOTATION_SAVE_PATH = './VOC2007/Annotations'



# Fixed values written into every generated VOC XML file.
_INDENT = ' ' * 4  # four spaces; the former '' * 4 was just the empty string
_NEW_LINE = '\n'
_FOLDER_NODE = 'WHUT2022'
_ROOT_NODE = 'annotation'
# NOTE(review): 'Dection' looks like a typo for 'Detection'; kept verbatim
# because it is emitted into every <source><database> node.
_DATABASE_NAME = 'Dection'
_ANNOTATION = 'WHUT2022'
_AUTHOR = 'CTX'
_SEGMENTED = '0'
_DIFFICULT = '0'
_TRUNCATED = '0'
_POSE = 'Unspecified'




#--------------------------------------------------------------------------------------------------------------------------------#
#   annotation_mode selects what get_txt generates:
#   0 - everything: the split files in VOC2007/ImageSets/Main (under VOCdevkit_path)
#       plus the training lists 2007_train.txt and 2007_val.txt
#   1 - only the split files in VOC2007/ImageSets/Main
#   2 - only the training lists 2007_train.txt and 2007_val.txt
#--------------------------------------------------------------------------------------------------------------------------------#
annotation_mode     = 0
#-------------------------------------------------------------------#
#   Class list used to write the boxes into 2007_train.txt / 2007_val.txt.
#   It must match the classes_path used for training and prediction;
#   if 2007_train.txt ends up with no boxes, the classes here are wrong.
#   Only used when annotation_mode is 0 or 2.
#   Kept equal to voc_classes_path, which getJsonAndClasses writes.
#-------------------------------------------------------------------#
classes_path        = voc_classes_path
#--------------------------------------------------------------------------------------------------------------------------------#
#   trainval_percent: fraction of all annotations assigned to train+val (the rest go to test).
#   train_percent:    fraction of train+val assigned to train (the rest go to val).
#   With 0.95 / 0.95 the ratios are trainval:test = 19:1 and train:val = 19:1
#   (use 0.9 / 0.9 for the common 9:1 splits).
#   Only used when annotation_mode is 0 or 1.
#--------------------------------------------------------------------------------------------------------------------------------#
trainval_percent    = 0.95
train_percent       = 0.95
#-------------------------------------------------------#
#   Root folder that contains VOC2007/
#   (the upstream default was 'VOCdevkit').
#-------------------------------------------------------#
VOCdevkit_path = './'






# Adapted from Fast R-CNN.
def convert_annotation(year, image_id, list_file):
    """Append the boxes of one VOC XML annotation to an open list file.

    Reads ``VOC<year>/Annotations/<image_id>.xml`` under ``VOCdevkit_path``.
    For every ``<object>`` whose ``<name>`` appears in the class file at
    ``classes_path`` and that is not marked difficult, this writes
    `` xmin,ymin,xmax,ymax,class_id`` (note the leading space) to ``list_file``.
    Coordinates are parsed as floats and truncated to ints. ``class_id`` is
    the 0-based line index of the class in ``classes_path``.

    Args:
        year: VOC year string used in the folder name, e.g. ``'2007'``.
        image_id: annotation file stem, without ``.xml``.
        list_file: writable text stream. The caller writes the image path
            before this call and the newline after it.
    """
    xml_path = os.path.join(VOCdevkit_path, 'VOC%s/Annotations/%s.xml' % (year, image_id))
    # Platform default encoding, matching how writeXMLFile writes these files.
    with open(xml_path) as in_file:
        root = ET.fromstring(in_file.read())

    with open(classes_path, encoding='utf-8') as f:
        classes = [c.strip() for c in f.readlines()]

    for obj in root.iter('object'):
        # A missing or empty <difficult> tag counts as "not difficult".
        difficult = 0
        difficult_node = obj.find('difficult')
        if difficult_node is not None and difficult_node.text and difficult_node.text.strip():
            difficult = int(difficult_node.text)

        cls = obj.find('name').text
        if cls not in classes or difficult == 1:
            continue
        cls_id = classes.index(cls)

        xmlbox = obj.find('bndbox')
        b = tuple(int(float(xmlbox.find(tag).text)) for tag in ('xmin', 'ymin', 'xmax', 'ymax'))
        list_file.write(" " + ",".join(str(a) for a in b) + ',' + str(cls_id))


# Adapted from Fast R-CNN.
def get_txt(annotation_mode, trainval_percent, train_percent):
    """Generate the dataset split files and/or the training list files.

    Args:
        annotation_mode: 0 runs both steps, 1 only writes the split files in
            VOC2007/ImageSets/Main, 2 only writes 2007_train.txt / 2007_val.txt.
        trainval_percent: fraction of annotations assigned to train+val
            (the rest go to test).
        train_percent: fraction of train+val assigned to train (the rest go
            to val).

    The split is drawn right after ``random.seed(0)``. It is therefore
    repeatable for the same ``os.listdir`` order of the Annotations folder.
    """
    random.seed(0)
    if annotation_mode in (0, 1):
        _write_image_sets(trainval_percent, train_percent)
    if annotation_mode in (0, 2):
        _write_train_val_lists()


def _write_image_sets(trainval_percent, train_percent):
    """Split Annotations/*.xml into trainval/train/val/test lists under ImageSets/Main."""
    print("Generate txt in ImageSets.")
    xmlfilepath = os.path.join(VOCdevkit_path, 'VOC2007/Annotations')
    saveBasePath = os.path.join(VOCdevkit_path, 'VOC2007/ImageSets/Main')
    # The output folder may not exist yet in a freshly generated dataset.
    os.makedirs(saveBasePath, exist_ok=True)
    total_xml = [name for name in os.listdir(xmlfilepath) if name.endswith(".xml")]

    num = len(total_xml)
    indices = range(num)
    tv = int(num * trainval_percent)
    tr = int(tv * train_percent)
    # Same sampling calls in the same order, so a given seed gives the same split;
    # sets are only used for O(1) membership tests below.
    trainval = random.sample(indices, tv)
    train = random.sample(trainval, tr)
    trainval_set = set(trainval)
    train_set = set(train)

    print("train and val size", tv)
    print("train size", tr)

    def _open(name):
        # utf-8 to match how _write_train_val_lists reads these files back.
        return open(os.path.join(saveBasePath, name), 'w', encoding='utf-8')

    with _open('trainval.txt') as ftrainval, _open('test.txt') as ftest, \
            _open('train.txt') as ftrain, _open('val.txt') as fval:
        for i in indices:
            name = total_xml[i][:-4] + '\n'  # drop the ".xml" suffix
            if i in trainval_set:
                ftrainval.write(name)
                if i in train_set:
                    ftrain.write(name)
                else:
                    fval.write(name)
            else:
                ftest.write(name)
    print("Generate txt in ImageSets done.")


def _write_train_val_lists():
    """Write 2007_train.txt / 2007_val.txt: one image path plus its boxes per line."""
    print("Generate 2007_train.txt and 2007_val.txt for train.")
    for year, image_set in (('2007', 'train'), ('2007', 'val')):
        set_path = os.path.join(VOCdevkit_path, 'VOC%s/ImageSets/Main/%s.txt' % (year, image_set))
        with open(set_path, encoding='utf-8') as f:
            image_ids = f.read().strip().split()
        with open('%s_%s.txt' % (year, image_set), 'w', encoding='utf-8') as list_file:
            for image_id in image_ids:
                list_file.write('%s/VOC%s/JPEGImages/%s.jpg' % (os.path.abspath(VOCdevkit_path), year, image_id))
                convert_annotation(year, image_id, list_file)
                list_file.write('\n')
    print("Generate 2007_train.txt and 2007_val.txt for train done.")


# 获取所有图像名
# for file in os.listdir(image_Path):
#     name = os.path.splitext(file)


# Build the per-object list from the JSON export and write the class names to voc_classes_path.
def getJsonAndClasses(txt_path,voc_classes_path):
    """Load the JSON label export and write the class-name list.

    The file at ``txt_path`` must contain a JSON object. Each key is an image
    key, and each value is a list of region dicts with the fields read here:
    ``regionClass``, ``xmin``, ``xmax``, ``ymin``, ``ymax``.

    ``regionClass`` may hold several class names separated by runs of ``-``
    or by the literal ``NAN``; each name becomes its own entry. Empty names
    and names starting with ``请`` are skipped (presumably placeholder prompt
    text — confirm against the labelling tool).

    Args:
        txt_path: path of the JSON export (read as UTF-8).
        voc_classes_path: output path. The distinct class names are written
            to it in sorted order, one per line, UTF-8. Sorting makes the class
            ids stable between runs; iterating a set of strings is not.

    Returns:
        list of ``{"filename": key, "class": name, "bndbox": [xmin, xmax, ymin, ymax]}``
        dicts in input order. Note the bndbox order; createObjectNode relies on it.
    """
    splitter = re.compile(r'-+|NAN')
    clsname = set()  # every distinct class name seen
    inputfile = []  # one dict per labelled object

    with open(txt_path, "r", encoding='UTF-8') as f:
        dic_all = json.load(f)

    for k, v in dic_all.items():
        for region in v:
            for cls in splitter.split(region['regionClass']):
                if cls == '' or cls[0] == "请":
                    continue
                clsname.add(cls)
                inputfile.append({
                    "filename": k,
                    "class": cls,
                    "bndbox": [region['xmin'], region['xmax'], region['ymin'], region['ymax']],
                })

    with open(voc_classes_path, 'w', encoding='UTF-8') as out:
        out.writelines(name + "\n" for name in sorted(clsname))

    return inputfile

# Helper that builds a <tag>text</tag> element.
def createElementNode(doc, tag, attr):
    """Return a new, unattached ``tag`` element of ``doc`` with one text child ``attr``."""
    text = doc.createTextNode(attr)
    node = doc.createElement(tag)
    node.appendChild(text)
    return node

# Helper that appends a <tag>text</tag> child to an existing node.
def createChildNode(doc, tag, attr, parent_node):
    """Create ``<tag>attr</tag>`` with ``doc`` and append it to ``parent_node``.

    Returns None; the new element is reachable through ``parent_node``.
    """
    child = doc.createElement(tag)
    child.appendChild(doc.createTextNode(attr))
    parent_node.appendChild(child)

# The <object> node has nested children (including <bndbox>), so it has its own builder.
def createObjectNode(doc, attrs):
    """Build a VOC ``<object>`` element for one labelled box.

    Args:
        doc: minidom Document used to create the nodes.
        attrs: one entry from getJsonAndClasses. Reads ``attrs['class']`` and
            ``attrs['bndbox']``, whose element order is [xmin, xmax, ymin, ymax].

    Returns:
        The unattached ``<object>`` element. It has name/pose/truncated/difficult
        children and a ``<bndbox>`` whose coordinates are truncated with ``int``.
    """
    obj = doc.createElement('object')
    header_fields = (
        ('name', attrs['class']),
        ('pose', _POSE),
        ('truncated', _TRUNCATED),
        ('difficult', _DIFFICULT),
    )
    for tag, text in header_fields:
        createChildNode(doc, tag, text, obj)

    box = doc.createElement('bndbox')
    coords = attrs['bndbox']
    # Source order is [xmin, xmax, ymin, ymax]; VOC expects xmin, ymin, xmax, ymax.
    for tag, idx in (('xmin', 0), ('ymin', 2), ('xmax', 1), ('ymax', 3)):
        createChildNode(doc, tag, str(int(coords[idx])), box)
    obj.appendChild(box)

    return obj


# Write a minidom Document to an XML file, without the <?xml ...?> declaration line.
def writeXMLFile(doc, filename, addindent=' ' * 4):
    """Serialise ``doc`` to ``filename``.

    The document is rendered in memory. No ``tmp.xml`` scratch file is
    created in the working directory, so nothing is clobbered or left behind.
    The leading XML declaration line is dropped and blank lines are removed.

    Args:
        doc: minidom Document to write.
        filename: destination path; overwritten if it exists.
        addindent: indentation added per nesting level (four spaces by default).
    """
    import io

    buffer = io.StringIO()
    doc.writexml(buffer, addindent=addindent, newl='\n', encoding='utf-8')
    buffer.seek(0)
    lines = buffer.readlines()

    # Platform default encoding, as before: convert_annotation reads these files
    # back with the same default. NOTE(review): the dropped declaration said
    # utf-8; moving both sides to encoding='utf-8' would be safer for non-ASCII
    # class names on Windows.
    with open(filename, 'w') as fout:
        fout.writelines(line for line in lines[1:] if line.split())





def _build_annotation_doc(save_name, width, height, channel, objects):
    """Return a minidom Document holding the VOC annotation of one image.

    Args:
        save_name: image file stem; the <filename> node is ``save_name + '.jpg'``.
        width, height, channel: image dimensions written into <size>.
        objects: the getJsonAndClasses entries for this image, in order.
    """
    doc = xml.dom.getDOMImplementation().createDocument(None, _ROOT_NODE, None)
    root_node = doc.documentElement

    createChildNode(doc, 'folder', _FOLDER_NODE, root_node)
    createChildNode(doc, 'filename', save_name + '.jpg', root_node)

    source_node = doc.createElement('source')
    createChildNode(doc, 'database', _DATABASE_NAME, source_node)
    createChildNode(doc, 'annotation', _ANNOTATION, source_node)
    createChildNode(doc, 'image', 'flickr', source_node)
    createChildNode(doc, 'flickrid', 'NULL', source_node)
    root_node.appendChild(source_node)

    owner_node = doc.createElement('owner')
    createChildNode(doc, 'flickrid', 'NULL', owner_node)
    createChildNode(doc, 'name', _AUTHOR, owner_node)
    root_node.appendChild(owner_node)

    size_node = doc.createElement('size')
    createChildNode(doc, 'width', str(width), size_node)
    createChildNode(doc, 'height', str(height), size_node)
    createChildNode(doc, 'depth', str(channel), size_node)
    root_node.appendChild(size_node)

    createChildNode(doc, 'segmented', _SEGMENTED, root_node)

    for ann in objects:
        root_node.appendChild(createObjectNode(doc, ann))
    return doc


def main():
    """Write one VOC XML per image in image_Path, then build the split/list files once."""
    from collections import defaultdict

    img_path = image_Path
    fileList = os.listdir(img_path)
    if not fileList:
        raise SystemExit("No images found in %s" % img_path)

    # Step 1: load the JSON objects (this also writes voc_classes_path).
    ann_data = getJsonAndClasses(txt_path, voc_classes_path)
    # Group objects by image key once, instead of rescanning ann_data for every image.
    anns_by_name = defaultdict(list)
    for ann in ann_data:
        anns_by_name[ann["filename"]].append(ann)

    os.makedirs(_ANNOTATION_SAVE_PATH, exist_ok=True)

    for imageName in fileList:
        # splitext removes only the extension; str.strip(".jpg") would also strip
        # any leading/trailing '.', 'j', 'p', 'g' characters ("img.jpg" -> "im").
        saveName = os.path.splitext(imageName)[0]
        print(saveName)

        xml_file_name = os.path.join(_ANNOTATION_SAVE_PATH, saveName + '.xml')

        image_file = os.path.join(img_path, imageName)
        img = cv2.imread(image_file)
        print(image_file)
        if img is None:
            # cv2.imread returns None for files it cannot decode.
            print("Skipping unreadable image:", image_file)
            continue
        height, width, channel = img.shape
        print(height, width, channel)

        doc = _build_annotation_doc(saveName, width, height, channel,
                                    anns_by_name.get(saveName, []))

        print(xml_file_name)
        writeXMLFile(doc, xml_file_name)

    # Step 2: generate the dataset splits once, after every annotation exists.
    get_txt(annotation_mode, trainval_percent, train_percent)


if __name__ == "__main__":
    main()

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值