目标检测数据集裁剪代码

网上能查到的脚本大多只是把目标区域从图像中裁剪出来,很少有对标准格式的目标检测数据集(图像连同标注)进行整体裁剪的脚本。因此我参照其他博主的代码,改写了一份数据集裁剪代码:把大图按固定尺寸切成小图,并为每张小图生成对应的标注文件。标签格式为 VOC(XML)格式。

import cv2
import os
import sys
import numpy as np
import glob
from multiprocessing import Pool
from functools import partial
import xml.etree.ElementTree as ET
from xml.dom.minidom import Document
from lxml import etree


def iou(BBGT, imgRect):
    """Clip ground-truth boxes to a tile and measure how much of each survives.

    Despite its name, this does not compute a symmetric intersection-over-union:
    the intersection area is divided by the area of the ground-truth box only,
    so the result is the fraction of each box that lies inside the tile.

    Args:
        BBGT: 2-D numeric array whose rows are ``[xmin, ymin, xmax, ymax]``.
        imgRect: 1-D numeric array ``[left, top, right, bottom]`` of the tile.

    Returns:
        A tuple ``(ratios, clipped)``. ``ratios`` holds one value per box
        (intersection area / box area, 0 for boxes with no positive area).
        ``clipped`` has one row per box: the element-wise max of the top-left
        corners followed by the element-wise min of the bottom-right corners.
        For boxes that do not overlap the tile this row is not a valid box.
    """
    left_top = np.maximum(BBGT[:, :2], imgRect[:2])
    right_bottom = np.minimum(BBGT[:, 2:], imgRect[2:])
    # Negative extents mean "no overlap"; clamp them so the area becomes 0.
    wh = np.maximum(right_bottom - left_top, 0)
    inter_area = wh[:, 0] * wh[:, 1]
    gt_area = (BBGT[:, 2] - BBGT[:, 0]) * (BBGT[:, 3] - BBGT[:, 1])
    # Degenerate boxes (zero area) would otherwise yield NaN/inf and a
    # RuntimeWarning; treat them as "not contained" instead.
    with np.errstate(divide='ignore', invalid='ignore'):
        ratios = inter_area / gt_area
    ratios[gt_area <= 0] = 0
    BB = np.concatenate((left_top, right_bottom), axis=1)
    return ratios, BB


def get_bbox(xml_path):
    """Read every object's bounding box and class name from an annotation XML.

    Each ``<object>`` element must contain a ``<name>`` and a ``<bndbox>`` with
    ``xmin``/``ymin``/``xmax``/``ymax`` children. Other children (for example
    ``<difficult>``) are not needed and may be absent.

    Args:
        xml_path: Path of the annotation file to parse.

    Returns:
        A 2-D array with one row ``[xmin, ymin, xmax, ymax, label]`` per
        object. Because numbers and the label share one array, NumPy stores
        all entries as strings. A file without objects yields shape ``(0, 5)``
        so callers can still slice columns.
    """
    BBGT = []
    root = ET.parse(xml_path).getroot()
    for obj in root.iter('object'):
        label = obj.find('name').text
        xmlbox = obj.find('bndbox')
        # Go through float() so coordinates written as "12.0" are accepted;
        # the fractional part is truncated.
        coords = [int(float(xmlbox.find(tag).text))
                  for tag in ('xmin', 'ymin', 'xmax', 'ymax')]
        BBGT.append(coords + [label])
    # reshape keeps the array two-dimensional even when no object was found.
    return np.array(BBGT).reshape(-1, 5)

def split(imgname, dirsrc, dirdst, subsize=800, gap=200, iou_thresh=0.3, ext='.png'):
    """Cut one annotated image into square tiles and write tile-level labels.

    The image is read from ``<dirsrc>/JPEGImages/<imgname>`` and its
    annotation from ``<dirsrc>/Anotations/<stem>.xml``. A ``subsize`` x
    ``subsize`` window slides over the image with a step of ``subsize - gap``;
    the last window of each row/column is shifted back so it ends at the image
    border. Tiles smaller than ``subsize`` (image smaller than the window) are
    zero-padded at the right/bottom. A tile is written only if at least one
    box has more than ``iou_thresh`` of its own area inside the tile; box
    coordinates are clipped to the tile and shifted to tile coordinates.

    Outputs go to ``<dirdst>/JPEGImages/<stem>_<left>_<top><ext>`` and
    ``<dirdst>/Anotations/<stem>_<left>_<top>.xml``; both directories must
    already exist.

    Args:
        imgname: File name (not path) of the image inside ``JPEGImages``.
        dirsrc: Dataset root containing ``JPEGImages`` and ``Anotations``.
        dirdst: Output root containing ``JPEGImages`` and ``Anotations``.
        subsize: Side length of each tile in pixels.
        gap: Overlap between neighbouring tiles in pixels.
        iou_thresh: Minimum fraction of a box that must lie inside a tile.
        ext: File extension (with dot) used for the written tiles.

    Raises:
        ValueError: If ``gap`` is not smaller than ``subsize`` (the window
            would never advance).
        FileNotFoundError: If the image cannot be read.
        OSError: If a tile image cannot be written.
    """
    stride = subsize - gap
    if stride <= 0:
        raise ValueError(f'gap ({gap}) must be smaller than subsize ({subsize})')

    # splitext keeps dots inside the base name ("a.b.jpg" -> "a.b").
    stem = os.path.splitext(imgname)[0]
    img_path = os.path.join(dirsrc, 'JPEGImages', imgname)
    img = cv2.imread(img_path, -1)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f'cannot read image: {img_path}')

    BBGT = get_bbox(os.path.join(dirsrc, 'Anotations', stem + '.xml'))
    if BBGT.size == 0:
        # No objects: every tile would be discarded anyway.
        return
    gt_boxes = BBGT[:, :4].astype('float32')
    labels = BBGT[:, 4:]

    img_h, img_w = img.shape[:2]
    top = 0  # y coordinate of the current tile's upper edge
    reachbottom = False
    while not reachbottom:
        if top + subsize >= img_h:
            reachbottom = True
            top = max(img_h - subsize, 0)
        bottom = min(top + subsize, img_h)
        left = 0
        reachright = False
        while not reachright:
            if left + subsize >= img_w:
                reachright = True
                left = max(img_w - subsize, 0)
            right = min(left + subsize, img_w)

            imgsplit = img[top:bottom, left:right]
            if imgsplit.shape[:2] != (subsize, subsize):
                # Pad to the full tile size, preserving channel count and
                # bit depth of the source image.
                template = np.zeros((subsize, subsize) + imgsplit.shape[2:],
                                    dtype=imgsplit.dtype)
                template[:imgsplit.shape[0], :imgsplit.shape[1]] = imgsplit
                imgsplit = template

            imgrect = np.array([left, top, right, bottom]).astype('float32')
            ious, X = iou(gt_boxes, imgrect)
            BB = np.concatenate((X, labels), axis=1)
            BBpatch = BB[ious > iou_thresh]
            # Tiles without any sufficiently visible box are dropped.
            if len(BBpatch) > 0:
                tile_stem = stem + '_' + str(left) + '_' + str(top)
                tile_name = tile_stem + ext
                tile_path = os.path.join(dirdst, 'JPEGImages', tile_name)
                if not cv2.imwrite(tile_path, imgsplit):
                    raise OSError(f'cannot write image: {tile_path}')

                xml = os.path.join(dirdst, 'Anotations', tile_stem + '.xml')
                ann = GEN_Annotations(tile_name)
                depth = imgsplit.shape[2] if imgsplit.ndim == 3 else 1
                # NumPy shape is (rows, cols) = (height, width).
                ann.set_size(imgsplit.shape[1], imgsplit.shape[0], depth)
                for bb in BBpatch:
                    # Entries are strings (mixed array), hence float() first.
                    x1 = int(float(bb[0])) - left
                    y1 = int(float(bb[1])) - top
                    x2 = int(float(bb[2])) - left
                    y2 = int(float(bb[3])) - top
                    ann.add_pic_attr(str(bb[4]), x1, y1, x2, y2)
                ann.savefile(xml)
            left += stride
        top += stride


class GEN_Annotations:
    """Incrementally build an annotation XML tree and write it to disk.

    The tree has an ``<annotation>`` root with ``<folder>``, ``<filename>``
    and ``<source>`` children created up front; ``<size>`` and ``<object>``
    elements are appended by the methods below.
    """

    def __init__(self, filename):
        """Create the root element and the fixed header fields.

        Args:
            filename: Text stored in the ``<filename>`` element.
        """
        self.root = etree.Element("annotation")

        folder = etree.SubElement(self.root, "folder")
        folder.text = "VOC2007"

        name_node = etree.SubElement(self.root, "filename")
        # lxml only accepts string text; coerce path-like or numeric values.
        name_node.text = str(filename)

        source = etree.SubElement(self.root, "source")

        source_annotation = etree.SubElement(source, "annotation")
        source_annotation.text = "PASCAL VOC2007"
        database = etree.SubElement(source, "database")
        database.text = "Unknown"

    def set_size(self, witdh, height, channel):
        """Append a ``<size>`` element.

        Args:
            witdh: Value written to ``<width>`` (parameter name kept as-is
                for backward compatibility).
            height: Value written to ``<height>``.
            channel: Value written to ``<depth>``.
        """
        size = etree.SubElement(self.root, "size")
        widthn = etree.SubElement(size, "width")
        widthn.text = str(witdh)
        heightn = etree.SubElement(size, "height")
        heightn.text = str(height)
        channeln = etree.SubElement(size, "depth")
        channeln.text = str(channel)

    def savefile(self, filename):
        """Write the tree to ``filename`` as pretty-printed UTF-8 XML.

        No XML declaration line is emitted.
        """
        tree = etree.ElementTree(self.root)
        tree.write(filename, pretty_print=True, xml_declaration=False, encoding='utf-8')

    def add_pic_attr(self, label, xmin, ymin, xmax, ymax, difficult=0):
        """Append one ``<object>`` element describing a labelled box.

        Args:
            label: Class name written to ``<name>``.
            xmin: Left edge of the box.
            ymin: Top edge of the box.
            xmax: Right edge of the box.
            ymax: Bottom edge of the box.
            difficult: Value written to ``<difficult>``; emitted so that
                readers which look this element up on every object can
                parse the generated files.
        """
        obj = etree.SubElement(self.root, "object")
        namen = etree.SubElement(obj, "name")
        # lxml rejects non-str text (e.g. integer class ids).
        namen.text = str(label)
        difficultn = etree.SubElement(obj, "difficult")
        difficultn.text = str(difficult)
        bndbox = etree.SubElement(obj, "bndbox")
        xminn = etree.SubElement(bndbox, "xmin")
        xminn.text = str(xmin)
        yminn = etree.SubElement(bndbox, "ymin")
        yminn.text = str(ymin)
        xmaxn = etree.SubElement(bndbox, "xmax")
        xmaxn.text = str(xmax)
        ymaxn = etree.SubElement(bndbox, "ymax")
        ymaxn.text = str(ymax)

if __name__ == '__main__':
    # Script entry point: tile every .jpg under <dirsrc>/JPEGImages and write
    # the results to <dirsrc>/data_crop.
    import tqdm

    dirsrc = r'C:\Users\LazyShark\Desktop\data_RZB_split\data'
    dirdst = os.path.join(dirsrc, 'data_crop')
    # makedirs creates missing parents and tolerates existing directories.
    for subdir in ('JPEGImages', 'Anotations'):
        os.makedirs(os.path.join(dirdst, subdir), exist_ok=True)

    subsize = 512      # tile side length in pixels
    gap = 0            # overlap between neighbouring tiles in pixels
    iou_thresh = 0.35  # minimum fraction of a box that must fall inside a tile
    ext = '.jpg'       # extension of the written tiles

    # Sorted so that runs process the images in a reproducible order.
    imglist = sorted(glob.glob(os.path.join(dirsrc, 'JPEGImages', '*.jpg')))
    imgnameList = [os.path.basename(imgpath) for imgpath in imglist]
    for imgname in tqdm.tqdm(imgnameList):
        split(imgname, dirsrc, dirdst, subsize, gap, iou_thresh, ext)

​

其中数据集的文件目录结构如下(注意标注目录名为 Anotations,拼写需与代码中保持一致):

data

--Anotations

----1.xml

----2.xml

--JPEGImages

----1.jpg

----2.jpg

代码参照 CSDN 博客文章《基于大遥感影像场景的目标检测数据集的裁剪方法(水平框)》。

  • 11
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值