网上查到的脚本大多只是把目标单独裁剪出来，很少有对标准格式数据集（整张图片及其标注）进行切片裁剪的脚本。因此参照其他博主的代码，改写了一份裁剪数据集的代码，其中标签格式为 VOC 格式。
import cv2
import os
import sys
import numpy as np
import glob
from multiprocessing import Pool
from functools import partial
import xml.etree.ElementTree as ET
from xml.dom.minidom import Document
from lxml import etree
def iou(BBGT, imgRect):
    """Return, for each box, the fraction of its own area lying inside *imgRect*.

    Despite the name this is intersection over the ground-truth box area, not
    intersection over union: a small box fully inside the tile scores 1.0 no
    matter how large the tile is.

    Args:
        BBGT: array of shape (N, 4) with rows ``[xmin, ymin, xmax, ymax]``.
        imgRect: 4 values ``[xmin, ymin, xmax, ymax]`` describing the tile.

    Returns:
        ``(ratio, BB)``. ``ratio`` has shape (N,); boxes with non-positive area
        get 0 instead of NaN. ``BB`` has shape (N, 4) and holds each box
        intersected with *imgRect*; for boxes that do not overlap the tile the
        corners come out inverted (min > max), so callers filter on ``ratio``.
    """
    BBGT = np.asarray(BBGT)
    imgRect = np.asarray(imgRect)
    left_top = np.maximum(BBGT[:, :2], imgRect[:2])
    right_bottom = np.minimum(BBGT[:, 2:4], imgRect[2:4])
    wh = np.maximum(right_bottom - left_top, 0)
    inter_area = wh[:, 0] * wh[:, 1]
    box_area = (BBGT[:, 2] - BBGT[:, 0]) * (BBGT[:, 3] - BBGT[:, 1])
    valid = box_area > 0
    # Divide only by positive areas so degenerate boxes (e.g. xmin == xmax)
    # yield 0 rather than NaN together with a RuntimeWarning.
    ratio = np.where(valid, inter_area / np.where(valid, box_area, 1), 0)
    BB = np.concatenate((left_top, right_bottom), axis=1)
    return ratio, BB
def get_bbox(xml_path):
    """Read the object boxes from a Pascal VOC annotation file.

    Args:
        xml_path: path to the VOC ``.xml`` file.

    Returns:
        A NumPy array of shape (N, 5) with rows ``[xmin, ymin, xmax, ymax, name]``.
        Because each row mixes numbers and a class name, NumPy stores every
        cell as a string; callers convert the first four columns with
        ``astype``. A file without ``<object>`` elements yields an empty (0, 5)
        array, so column slicing such as ``[:, :4]`` still works.
    """
    root = ET.parse(xml_path).getroot()
    BBGT = []
    for obj in root.iter('object'):
        box = obj.find('bndbox')
        # Some labelling tools write fractional coordinates ("12.5");
        # int(float(...)) accepts those as well as plain integers.
        xmin, ymin, xmax, ymax = (
            int(float(box.find(tag).text)) for tag in ('xmin', 'ymin', 'xmax', 'ymax')
        )
        BBGT.append([xmin, ymin, xmax, ymax, obj.find('name').text])
    if not BBGT:
        return np.empty((0, 5), dtype=str)
    return np.array(BBGT)
def _pad_to_square(patch, subsize):
    """Return *patch* zero-padded at the bottom/right to ``subsize x subsize``.

    The patch's dtype is preserved (``cv2.imread(path, -1)`` can return 16-bit
    data) and so is its channel layout (2-D grayscale or H x W x C).
    """
    if patch.shape[:2] == (subsize, subsize):
        return patch
    padded = np.zeros((subsize, subsize) + patch.shape[2:], dtype=patch.dtype)
    padded[:patch.shape[0], :patch.shape[1]] = patch
    return padded


def _tile_origins(length, subsize, step):
    """Yield tile start offsets along one image axis of size *length*.

    Offsets advance by *step*; the last tile is shifted back so it ends at the
    image edge (or starts at 0 when the axis is shorter than *subsize*).
    """
    start = 0
    while start + subsize < length:
        yield start
        start += step
    yield max(length - subsize, 0)


def split(imgname, dirsrc, dirdst, subsize=800, gap=200, iou_thresh=0.3, ext='.png'):
    """Cut one image and its VOC boxes into ``subsize x subsize`` tiles.

    Reads ``<dirsrc>/JPEGImages/<imgname>`` and ``<dirsrc>/Anotations/<stem>.xml``.
    For every tile that keeps at least one box it writes
    ``<dirdst>/JPEGImages/<stem>_<left>_<top><ext>`` and the matching ``.xml``
    under ``<dirdst>/Anotations``. Tiles without boxes are not written.

    Args:
        imgname: image file name inside ``<dirsrc>/JPEGImages``.
        dirsrc: source dataset root.
        dirdst: output root; its ``JPEGImages`` and ``Anotations`` subfolders
            must already exist.
        subsize: tile edge length in pixels. Edge tiles are shifted back inside
            the image; parts of a tile beyond the image are zero-padded.
        gap: overlap in pixels between neighbouring tiles (stride = subsize - gap).
        iou_thresh: a box is kept in a tile when more than this fraction of its
            own area lies inside the tile; kept boxes are clipped to the tile.
        ext: extension of the written tile images (selects the encoder used by
            ``cv2.imwrite``).

    Raises:
        ValueError: if ``gap >= subsize`` (the tiling would never advance).
        FileNotFoundError: if the image cannot be read.
        OSError: if a tile image cannot be written.
    """
    step = subsize - gap
    if step <= 0:
        raise ValueError(f'gap ({gap}) must be smaller than subsize ({subsize})')

    # splitext keeps dots inside the base name ("a.b.jpg" -> "a.b").
    stem = os.path.splitext(imgname)[0]
    img_path = os.path.join(dirsrc, 'JPEGImages', imgname)
    img = cv2.imread(img_path, -1)
    if img is None:
        raise FileNotFoundError(f'cannot read image: {img_path}')

    BBGT = get_bbox(os.path.join(dirsrc, 'Anotations', stem + '.xml'))
    if len(BBGT) == 0:
        return  # no objects: every tile would be discarded anyway
    # Keep numeric coordinates and class names apart instead of concatenating
    # a float array with a string array.
    coords = BBGT[:, :4].astype('float32')
    labels = BBGT[:, 4]

    img_h, img_w = img.shape[:2]
    channels = img.shape[2] if img.ndim == 3 else 1
    for top in _tile_origins(img_h, subsize, step):
        for left in _tile_origins(img_w, subsize, step):
            bottom = min(top + subsize, img_h)
            right = min(left + subsize, img_w)
            imgrect = np.array([left, top, right, bottom], dtype='float32')
            ratios, clipped = iou(coords, imgrect)
            keep = ratios > iou_thresh
            if not keep.any():
                continue  # abandon tiles with no boxes

            tile = _pad_to_square(img[top:bottom, left:right], subsize)
            tile_name = f'{stem}_{left}_{top}'
            tile_img = tile_name + ext
            if not cv2.imwrite(os.path.join(dirdst, 'JPEGImages', tile_img), tile):
                raise OSError(f'cannot write tile image: {tile_img}')

            # <filename> names the tile image this annotation belongs to.
            ann = GEN_Annotations(tile_img)
            ann.set_size(tile.shape[1], tile.shape[0], channels)  # width, height, depth
            for (x1, y1, x2, y2), label in zip(clipped[keep], labels[keep]):
                # Clipped boxes are shifted into tile-local pixel coordinates.
                ann.add_pic_attr(str(label), int(x1) - left, int(y1) - top,
                                 int(x2) - left, int(y2) - top)
            ann.savefile(os.path.join(dirdst, 'Anotations', tile_name + '.xml'))
class GEN_Annotations:
    """Build a Pascal VOC style annotation tree with ``lxml.etree`` and save it.

    Typical use: create one instance per image, call ``set_size`` once,
    ``add_pic_attr`` once per object, then ``savefile``.
    """

    def __init__(self, filename):
        """Create the ``<annotation>`` root with folder, filename and source info.

        Args:
            filename: text stored in the ``<filename>`` element.
        """
        self.root = etree.Element("annotation")
        etree.SubElement(self.root, "folder").text = "VOC2007"
        etree.SubElement(self.root, "filename").text = filename
        source = etree.SubElement(self.root, "source")
        etree.SubElement(source, "annotation").text = "PASCAL VOC2007"
        etree.SubElement(source, "database").text = "Unknown"

    def set_size(self, witdh, height, channel):
        """Append a ``<size>`` element.

        Args:
            witdh: image width in pixels (misspelled name kept so keyword
                callers keep working).
            height: image height in pixels.
            channel: number of channels, written as ``<depth>``.
        """
        size = etree.SubElement(self.root, "size")
        etree.SubElement(size, "width").text = str(witdh)
        etree.SubElement(size, "height").text = str(height)
        etree.SubElement(size, "depth").text = str(channel)

    def savefile(self, filename):
        """Write the tree to *filename* as pretty-printed UTF-8 XML, without an XML declaration."""
        tree = etree.ElementTree(self.root)
        tree.write(filename, pretty_print=True, xml_declaration=False, encoding='utf-8')

    def add_pic_attr(self, label, xmin, ymin, xmax, ymax, difficult=0):
        """Append one ``<object>`` with its class name, difficult flag and box.

        Args:
            label: class name written to ``<name>``.
            xmin, ymin, xmax, ymax: box corners, written with ``str()``.
            difficult: value written to ``<difficult>`` (default 0). Standard VOC
                annotations carry this element for every object, and readers
                that call ``find('difficult').text`` fail when it is missing.
        """
        obj = etree.SubElement(self.root, "object")
        etree.SubElement(obj, "name").text = label
        etree.SubElement(obj, "difficult").text = str(difficult)
        bndbox = etree.SubElement(obj, "bndbox")
        for tag, value in (("xmin", xmin), ("ymin", ymin), ("xmax", xmax), ("ymax", ymax)):
            etree.SubElement(bndbox, tag).text = str(value)
if __name__ == '__main__':
    import tqdm

    # Dataset root containing JPEGImages/ and Anotations/; the first
    # command-line argument, if given, overrides the default path.
    dirsrc = sys.argv[1] if len(sys.argv) > 1 else r'C:\Users\LazyShark\Desktop\data_RZB_split\data'
    dirdst = os.path.join(dirsrc, 'data_crop')
    # makedirs also creates dirdst itself and is a no-op for existing folders.
    for sub in ('JPEGImages', 'Anotations'):
        os.makedirs(os.path.join(dirdst, sub), exist_ok=True)

    subsize = 512      # tile edge length in pixels
    gap = 0            # overlap between neighbouring tiles (stride = subsize - gap)
    iou_thresh = 0.35  # keep a box in a tile when more than this fraction of its area is inside
    ext = '.jpg'       # extension (and therefore format) of the written tiles

    # Sorted so runs process images in a reproducible order.
    imglist = sorted(glob.glob(os.path.join(dirsrc, 'JPEGImages', '*.jpg')))
    imgnameList = [os.path.basename(imgpath) for imgpath in imglist]
    for imgname in tqdm.tqdm(imgnameList):
        split(imgname, dirsrc, dirdst, subsize, gap, iou_thresh, ext)
其中数据集的目录结构如下（注意：标注文件夹名为 Anotations，与代码中的拼写保持一致）：
data
--Anotations
----1.xml
----2.xml
--JPEGImages
----1.jpg
----2.jpg