ImageNet 标注文件的读取与写入

image_label_util.py

#coding:utf-8
import os, cv2, shutil, random, codecs, HTMLParser
from lxml import etree
from lxml.etree import Element, SubElement, tostring

class PicAnno:
    """Annotation record for one picture: folder, file name, size and objects.

    Every field is created in ``__init__`` (``None`` until its setter is
    called), so reading e.g. ``depth`` on a fresh instance yields ``None``
    instead of raising AttributeError.
    """

    def __init__(self, folder):
        """Create an empty annotation that belongs to ``folder``."""
        self.folder = folder
        # Declared up front so they can be read before the setters run.
        self.filename = None
        self.width = None
        self.height = None
        self.depth = None
        # Per-instance list; a class-level list would be shared by all
        # instances that did not rebind it.
        self.objects = []

    def set_folder(self, folder):
        """Replace the folder name."""
        self.folder = folder

    def set_filename(self, filename):
        """Store the picture's file name."""
        self.filename = filename

    def set_size(self, width, height, depth):
        """Store the picture's width, height and depth as given."""
        self.width = width
        self.height = height
        self.depth = depth

    def add_object(self, object):
        """Append one labelled object to this annotation.

        The parameter keeps its original name (which shadows the builtin)
        so that keyword callers continue to work.
        """
        self.objects.append(object)


class PicObject:
    """One labelled object: a name plus optional pose, flags and bounding box.

    Only ``name`` exists after construction; every other attribute is
    created by the corresponding setter.
    """

    def __init__(self, name):
        """Create an object carrying the label ``name``."""
        self.name = name

    def set_name(self, name):
        """Replace the label."""
        self.name = name

    def set_pose(self, pose):
        """Store the pose value as given."""
        self.pose = pose

    def set_truncated(self, truncated):
        """Store the truncated flag as given."""
        self.truncated = truncated

    def set_difficult(self, difficult):
        """Store the difficult flag as given."""
        self.difficult = difficult

    def set_bndbox(self, xmin, ymin, xmax, ymax):
        """Store the four bounding-box coordinates as given."""
        self.xmin, self.ymin = xmin, ymin
        self.xmax, self.ymax = xmax, ymax


class VocUtil:
    """Read, build and save annotation XML, plus small list-file helpers.

    The XML layout handled here is ``annotation`` -> ``folder``,
    ``filename``, ``size`` (``width``/``height``/``depth``) and any number
    of ``object`` elements (``name``, ``pose``, ``truncated``,
    ``difficult``, ``bndbox``). XML handling uses the ``etree``,
    ``Element``, ``SubElement`` and ``tostring`` names imported from lxml
    at the top of the file.
    """

    @staticmethod
    def _first_text(node, path):
        """Return the text of the first element matching ``path``.

        Raises IndexError when nothing matches, exactly like a direct
        ``node.xpath(path)[0]`` lookup.
        """
        return node.xpath(path)[0].text

    @staticmethod
    def _append_field(parent, tag, source, attr, stringify=True):
        """Append ``<tag>`` to ``parent``, filled from ``source.attr``.

        The element is always created; its text is set only when the
        attribute exists and is not None. ``stringify=False`` assigns the
        value untouched instead of passing it through ``str()``.
        """
        node = SubElement(parent, tag)
        value = getattr(source, attr, None)
        if value is not None:
            node.text = str(value) if stringify else value
        return node

    def read_anno_xml(self, xml_path):
        """Parse the annotation file at ``xml_path`` into a PicAnno.

        Values are stored as the raw element text; nothing is converted
        to numbers.

        Raises:
            IndexError: if one of the looked-up elements is absent.
        """
        # NOTE: a commented-out variant in the original read GBK files with
        # codecs.open(xml_path, 'r', 'gbk') and etree.fromstring instead.
        root = etree.parse(xml_path).getroot()
        text_of = self._first_text

        anno = PicAnno(text_of(root, '/annotation/folder'))
        anno.set_filename(text_of(root, '/annotation/filename'))
        anno.set_size(text_of(root, '/annotation/size/width'),
                      text_of(root, '/annotation/size/height'),
                      text_of(root, '/annotation/size/depth'))
        for obj_node in root.xpath('/annotation/object'):
            pic_obj = PicObject(text_of(obj_node, 'name'))
            pic_obj.set_pose(text_of(obj_node, 'pose'))
            pic_obj.set_truncated(text_of(obj_node, 'truncated'))
            pic_obj.set_difficult(text_of(obj_node, 'difficult'))
            pic_obj.set_bndbox(text_of(obj_node, 'bndbox/xmin'),
                               text_of(obj_node, 'bndbox/ymin'),
                               text_of(obj_node, 'bndbox/xmax'),
                               text_of(obj_node, 'bndbox/ymax'))
            anno.add_object(pic_obj)
        return anno

    def parse_anno_xml(self, picAnno):
        """Serialize ``picAnno`` into a pretty-printed XML text string.

        Every element of the layout is emitted; an element whose source
        attribute is missing or None is left empty. ``folder``,
        ``filename`` and object ``name`` are assigned as-is, all other
        values go through ``str()``.

        Returns:
            The document as text (not bytes), without an XML declaration.
        """
        add = self._append_field
        node_root = Element('annotation')
        add(node_root, 'folder', picAnno, 'folder', stringify=False)
        add(node_root, 'filename', picAnno, 'filename', stringify=False)

        node_size = SubElement(node_root, 'size')
        for attr in ('width', 'height', 'depth'):
            add(node_size, attr, picAnno, attr)

        for obj in picAnno.objects:
            node_object = SubElement(node_root, 'object')
            add(node_object, 'name', obj, 'name', stringify=False)
            for attr in ('pose', 'truncated', 'difficult'):
                add(node_object, attr, obj, attr)
            node_bndbox = SubElement(node_object, 'bndbox')
            for attr in ('xmin', 'ymin', 'xmax', 'ymax'):
                add(node_bndbox, attr, obj, attr)

        # Serializing straight to text keeps non-ASCII characters readable
        # without a later unescape step; that step also turned &amp; / &lt;
        # back into raw characters and so produced malformed XML.
        return tostring(node_root, pretty_print=True, encoding='unicode')

    def save_anno_xml(self, xml_path, xml_text):
        """Write ``xml_text`` to ``xml_path`` encoded as UTF-8.

        ``xml_text`` may be text or a UTF-8 encoded byte string.
        """
        if isinstance(xml_text, bytes):
            # The codecs writer wants text; decode explicitly rather than
            # depending on an implicit conversion.
            xml_text = xml_text.decode('utf-8')
        with codecs.open(xml_path, 'w', 'utf-8') as f:
            f.write(xml_text)

    def readFile(self, path):
        """Return the lines of ``path`` with surrounding whitespace removed."""
        with open(path, 'r') as handle:
            return [line.strip() for line in handle]

    def writeLines(self, file_path, lines):
        """Write ``lines`` to ``file_path``, one stripped entry per line.

        Missing parent directories are created first. A bare file name has
        an empty directory part, for which nothing needs to be created.
        """
        file_dir = os.path.dirname(file_path)
        if file_dir and not os.path.exists(file_dir):
            os.makedirs(file_dir)
        with open(file_path, 'w') as out:
            for line in lines:
                out.write(line.strip() + '\n')

    def gene_train_test_val_txt(self, anno_dir, txt_dir):
        """Write test/train/trainval/val list files for ``anno_dir``.

        Each ``.xml`` file contributes its name without the extension.

        NOTE(review): all four files receive the complete name list, only
        in a different random order -- no partitioning happens here.
        """
        # splitext removes only the final extension, so names that contain
        # dots are kept intact.
        pic_names = [os.path.splitext(entry)[0]
                     for entry in os.listdir(anno_dir)
                     if entry.endswith('.xml')]
        for list_name in ('test.txt', 'train.txt', 'trainval.txt', 'val.txt'):
            random.shuffle(pic_names)
            self.writeLines(os.path.join(txt_dir, list_name), pic_names)
        print('生成测试集、训练集、训练验证集、验证集完成!')
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值