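# 图像增强,旋转,翻转,缩放 (image augmentation: rotate, flip, scale)
# xml解析,增强处理,写入xml (parse the xml, apply the augmentation, write the xml back)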
import cv2
import os
import imgaug as ia
import imgaug.augmenters as iaa
from imgaug.augmentables.bbs import BoundingBox, BoundingBoxesOnImage
from tqdm import tqdm
ia.seed(1)
import xml.etree.ElementTree as ET
from xml.etree.ElementTree import parse
# 解析xml文件
class XmlParse:
    """Load an annotation XML file and write an augmented copy of it.

    The file is expected to have a root element containing one ``size``
    element (with ``width``, ``height`` and ``depth`` children) and any
    number of ``object`` elements, each holding a ``bndbox`` with
    ``xmin``, ``ymin``, ``xmax`` and ``ymax`` children.

    Attributes set by ``__init__``:
        xml_path: path of the source annotation file.
        posfix: suffix inserted before the extension of the saved file.
        xml_file: the parsed ``ElementTree``.
        root: root element of the tree.
        root_objects: list of the ``object`` elements still in the tree.
        before_bbox_num: number of ``object`` elements found at load time.
        root_size: the ``size`` element.
        size_before: ``(height, width, depth)`` read from ``size``.
    """

    def __init__(self, xml_path, posfix='_after'):
        """Parse ``xml_path`` and cache the elements that will be edited.

        :param xml_path: path to the annotation file to read.
        :param posfix: suffix appended to the file stem when saving.
        """
        self.xml_path = xml_path
        self.posfix = posfix
        self.xml_file = ET.parse(xml_path)
        self.root = self.xml_file.getroot()
        self.root_objects = self.root.findall('object')
        self.before_bbox_num = len(self.root_objects)
        self.root_size = self.root.find('size')
        w = int(self.root_size.find('width').text)
        h = int(self.root_size.find('height').text)
        d = int(self.root_size.find('depth').text)
        # Stored in the same (rows, cols, channels) order as an image shape.
        self.size_before = (h, w, d)

    def parseBBoxs(self):
        """Collect the bounding box of every ``object`` element.

        :return: list of ``[xmin, ymin, xmax, ymax]`` integer lists, in
            document order; also stored on ``self.before_bboxs``.
        """
        self.before_bboxs = []
        for per_obj in self.root_objects:
            bndbox = per_obj.find('bndbox')
            self.before_bboxs.append(
                [int(bndbox.find(tag).text)
                 for tag in ('xmin', 'ymin', 'xmax', 'ymax')]
            )
        return self.before_bboxs

    def reWrite_Save(self, after_image, after_bboxs, save_folder):
        """Write the augmented boxes into the tree and save it as a new file.

        Boxes are matched to ``object`` elements by position: the i-th box
        overwrites the i-th object. Objects beyond ``len(after_bboxs)`` are
        removed from the tree; boxes beyond the number of objects are
        ignored. Coordinates are converted with ``int()``.

        :param after_image: augmented image; only its ``shape`` is read.
            A 2-D shape is recorded with a depth of 1.
        :param after_bboxs: sequence of ``[x1, y1, x2, y2]`` boxes.
        :param save_folder: output folder; created if it does not exist.
        :return: None. The result is written to ``self.save_path``.
        """
        # splitext only touches the real extension, unlike str.replace which
        # would rewrite every ".xml" occurrence and miss ".XML".
        stem, ext = os.path.splitext(os.path.basename(self.xml_path))
        self.save_filename = stem + self.posfix + ext
        self.save_path = os.path.join(save_folder, self.save_filename)
        if save_folder:
            # ElementTree.write does not create missing directories.
            os.makedirs(save_folder, exist_ok=True)

        self.size_aug = after_image.shape
        h, w = self.size_aug[0], self.size_aug[1]
        # Single-channel images have a 2-D shape with no channel entry.
        d = self.size_aug[2] if len(self.size_aug) > 2 else 1
        if (h, w, d) != self.size_before:
            # The augmentation changed the image size; keep the xml in sync.
            self.root_size.find('width').text = str(w)
            self.root_size.find('height').text = str(h)
            self.root_size.find('depth').text = str(d)

        after_bbox_num = len(after_bboxs)
        kept_objects = self.root_objects[:after_bbox_num]
        surplus_objects = self.root_objects[after_bbox_num:]

        for per_obj, bbox in zip(kept_objects, after_bboxs):
            root_bndbox = per_obj.find('bndbox')
            for pos, tag in enumerate(('xmin', 'ymin', 'xmax', 'ymax')):
                root_bndbox.find(tag).text = str(int(bbox[pos]))

        for per_obj in surplus_objects:
            self.root.remove(per_obj)
        # Forget the removed elements so a later call does not try to
        # remove them from the tree a second time.
        self.root_objects = kept_objects

        self.xml_file.write(self.save_path, encoding="utf-8", xml_declaration=False)
def augment1(image=None, bboxs=None):
    """Apply a fixed augmentation pipeline to an image and its boxes.

    The pipeline multiplies brightness, translates and scales, rotates by
    45 degrees and applies an average blur. Boxes that end up outside the
    augmented image are dropped and the remaining ones are clipped to the
    image, so the returned list may be shorter than ``bboxs``.

    NOTE(review): dropped boxes are not necessarily the last ones, so the
    i-th returned box is not guaranteed to belong to the i-th input box.
    Callers that match boxes to labels by index should confirm this is
    acceptable.

    :param image: image array, e.g. an ndarray of shape (424, 700, 3).
    :param bboxs: list of ``[x1, y1, x2, y2]`` boxes.
    :return: ``(image_aug, bboxes_aug)`` where ``bboxes_aug`` is a list of
        ``[x1, y1, x2, y2]`` lists for the boxes that survived.
    """
    bbs = BoundingBoxesOnImage(
        [BoundingBox(box[0], box[1], box[2], box[3]) for box in bboxs],
        shape=image.shape,
    )

    # ============== augmentation steps: add or remove as needed ==============
    # https://imgaug.readthedocs.io/en/latest/source/api.html
    seq = iaa.Sequential([
        iaa.Multiply((1.2, 1.5)),  # change brightness, doesn't affect BBs
        iaa.Affine(
            translate_px={"x": 40, "y": 60},
            scale=(0.5, 0.7)
        ),  # translate by 40/60px on x/y axis, and scale to 50-70%, affects BBs
        iaa.Affine(rotate=45),  # rotate
        iaa.blur.AverageBlur()  # blur
    ])
    # =========================================================================

    # Augment the image and the boxes together so they stay aligned.
    image_aug, bbs_aug = seq(image=image, bounding_boxes=bbs)

    # Drop boxes that left the image and clip the ones that stick out.
    bbs_aug = bbs_aug.remove_out_of_image().clip_out_of_image()

    # Iterate over the boxes that survived, not over the original count:
    # indexing bbs_aug with the original length raises IndexError as soon
    # as remove_out_of_image() has dropped a box.
    bboxes_aug = [
        [after.x1, after.y1, after.x2, after.y2]
        for after in bbs_aug.bounding_boxes
    ]
    return image_aug, bboxes_aug
if __name__ == '__main__':
    # Augment every image in `images_folder` together with the annotation
    # of the same stem in `xml_folder`, and save both with `posfix` added.

    # ====================== parameters ====================
    # Folders holding the original images and their xml annotations.
    images_folder = 'images'
    xml_folder = 'labels'
    # Folders receiving the augmented images and annotations.
    save_img_folder = 'save/aug_imgs'
    save_xml_folder = 'save/aug_labels'
    # Suffix added to the file stem, e.g. 1id.jpg --> 1id_after.jpg
    posfix = '_after'

    # cv2.imwrite returns False instead of raising when the target folder
    # is missing, so make sure both output folders exist up front.
    os.makedirs(save_img_folder, exist_ok=True)
    os.makedirs(save_xml_folder, exist_ok=True)

    # os.listdir returns entries in arbitrary order; sort for a stable run
    # and pair files by stem instead of relying on two listings lining up.
    img_files = sorted(os.listdir(images_folder))
    xml_files = sorted(os.listdir(xml_folder))
    img_paths = [os.path.join(images_folder, file) for file in img_files]
    xml_by_stem = {
        os.path.splitext(file)[0]: os.path.join(xml_folder, file)
        for file in xml_files
    }

    for img_path in tqdm(img_paths):
        img_stem, img_ext = os.path.splitext(os.path.basename(img_path))
        xml_path = xml_by_stem.get(img_stem)
        if xml_path is None:
            raise FileNotFoundError(
                "no annotation with stem {!r} in {!r} for image {!r}".format(
                    img_stem, xml_folder, img_path))

        img_before = cv2.imread(img_path)
        if img_before is None:
            # cv2.imread signals an unreadable file by returning None.
            raise IOError("could not read image {!r}".format(img_path))

        # Load the annotation and extract its boxes.
        xml_instance = XmlParse(xml_path, posfix)
        bboxes_before = xml_instance.parseBBoxs()

        # Augment; the boxes come back as a list of [x1, y1, x2, y2].
        image_aug, bboxes_aug = augment1(img_before, bboxes_before)

        # Save the image, keeping whatever extension the source file had.
        img_aug_filename = img_stem + posfix + img_ext
        img_aug_path = os.path.join(save_img_folder, img_aug_filename)
        if not cv2.imwrite(img_aug_path, image_aug):
            raise IOError("could not write image {!r}".format(img_aug_path))

        # Update the annotation and save it next to the augmented labels.
        xml_instance.reWrite_Save(image_aug, bboxes_aug, save_xml_folder)