注:代码修改自另一位博主的文章,很抱歉找不到地址了。
做出以下修改:
1、把控制参数放在头部便于修改
2、原代码的增强前后图片和标注存放于一个文件夹,这里存放在新文件夹。
3、做出一些注释
注意:
1、visualize部分未修改,可能会因为我改了变量名跑不通
2、缩放方法可能造成标注偏移,增强之后需要检查
albumentations包下载
如果一直下载不了(time out),用清华镜像源:https://pypi.tuna.tsinghua.edu.cn/simple
pip install albumentations -i https://pypi.tuna.tsinghua.edu.cn/simple
import cv2
from matplotlib import pyplot as plt
import xml.etree.ElementTree as ET
import albumentations as A
import os
import time
# Control parameters (kept at the top of the file so they are easy to edit)
BOX_COLOR = (255, 0, 0)  # Red
TEXT_COLOR = (255, 255, 255)  # White
# Number of augmented copies generated per source image: N source images
# produce N * GENERATED_PICS_SIZE outputs (e.g. 62 images -> 62 * 20).
# The augmentations themselves are configured by the A.Compose(...) pipeline
# used in main().
GENERATED_PICS_SIZE = 20
# Dataset root directory
DIR = "D:\\AI\\data6"
# Sub-folder of DIR holding the source images. They are read with
# cv2.imread; augmented copies are always saved as .jpg.
IMAGES_FILE = "images"
# Sub-folder of DIR holding the original Pascal VOC .xml annotations
ANNOTATIONS_FILE = "annotations"
# 0-based index of <bndbox> among the children of each <object> in the
# original XML (4 = the fifth child). Check it against your annotation files.
OBJ_NUM = 4
def visualize_bbox(img, bbox, class_name, color=BOX_COLOR, thickness=2):
    """Draw one labelled bounding box onto ``img`` and return it.

    OpenCV draws directly into ``img``, so the array passed in is modified
    and the same object is returned.

    :param img: image array to draw on.
    :param bbox: ``(x_min, y_min, x_max, y_max)`` in pixels.
    :param class_name: text drawn in a filled label just above the box.
    :param color: colour of the box outline and of the label background.
    :param thickness: outline thickness in pixels.
    :return: ``img``.
    """
    x_min, y_min, x_max, y_max = (int(v) for v in bbox)
    cv2.rectangle(img, (x_min, y_min), (x_max, y_max),
                  color=color, thickness=thickness)
    ((text_width, text_height), _) = cv2.getTextSize(class_name,
                                                     cv2.FONT_HERSHEY_SIMPLEX,
                                                     0.35, 1)
    # Filled label background above the box. It uses the same colour as the
    # outline, so a caller-supplied ``color`` applies to both.
    cv2.rectangle(img, (x_min, y_min - int(1.3 * text_height)),
                  (x_min + text_width, y_min), color, -1)
    cv2.putText(
        img,
        text=class_name,
        org=(x_min, y_min - int(0.3 * text_height)),
        fontFace=cv2.FONT_HERSHEY_SIMPLEX,
        fontScale=0.35,
        color=TEXT_COLOR,
        lineType=cv2.LINE_AA,
    )
    return img
def visualize(image, bboxes, category_ids, category_id_to_name):
    """Display a copy of ``image`` with every box and its class label drawn.

    :param image: image array; drawing happens on a copy.
    :param bboxes: boxes in the format accepted by ``visualize_bbox``.
    :param category_ids: one id per box, looked up in ``category_id_to_name``.
    :param category_id_to_name: mapping from category id to display name.
    """
    canvas = image.copy()
    for box, cat_id in zip(bboxes, category_ids):
        canvas = visualize_bbox(canvas, box, category_id_to_name[cat_id])
    plt.axis('off')
    plt.imshow(canvas)
    plt.show()
def saveNewAnnotation(new_xml_path, new_image_path, xml_path, bboxes, cur_dir):
    """Write the annotation of an augmented image, using the original as a template.

    The ``folder``, ``filename`` and ``path`` fields are rewritten for the new
    image, and each ``<object>``'s ``<bndbox>`` receives the corresponding box
    from ``bboxes``. Everything else is copied unchanged. The result is
    written to ``new_xml_path`` as UTF-8.

    :param new_xml_path: destination path of the new XML file.
    :param new_image_path: file *name* (not a full path) of the augmented
        image, e.g. ``1_001.jpg``. ``<path>`` becomes
        ``cur_dir/images_aug/<name>``.
    :param xml_path: original annotation file used as the template.
    :param bboxes: ``(xmin, ymin, xmax, ymax)`` boxes, one per ``<object>``
        and in the same order. Coordinates are rounded to integers.
    :param cur_dir: dataset root directory.
    :raises ValueError: if there are fewer boxes than ``<object>`` elements
        (e.g. an augmentation dropped a box), because the remaining boxes can
        no longer be matched to their objects reliably.
    """
    with open(xml_path, encoding='utf-8') as in_file:
        tree = ET.parse(in_file)
    root = tree.getroot()

    # Update header fields by tag name rather than by child position, so a
    # file whose elements are ordered differently is not silently corrupted.
    header = {
        'folder': "images",
        'filename': new_image_path,
        'path': os.path.join(cur_dir, 'images_aug', new_image_path),
    }
    for tag, value in header.items():
        node = root.find(tag)
        if node is not None:
            node.text = value

    objects = list(root.iter('object'))
    if len(bboxes) < len(objects):
        raise ValueError(
            "%s has %d objects but only %d boxes were given"
            % (xml_path, len(objects), len(bboxes)))
    for obj, box in zip(objects, bboxes):
        # Locate <bndbox> by name (as getAnnotation does) instead of relying
        # on its position among the object's children.
        bndbox = obj.find('bndbox')
        for tag, value in zip(('xmin', 'ymin', 'xmax', 'ymax'), box):
            # float() normalises numpy scalars (which the augmentation
            # library may return) so the text is a plain integer like "12".
            bndbox.find(tag).text = str(round(float(value)))
    tree.write(new_xml_path, 'UTF-8')
def getAnnotation(xml_path):
    """Read the boxes and class names from a Pascal VOC annotation file.

    :param xml_path: path of the XML annotation, read as UTF-8.
    :return: ``(bboxes, category_ids)``. ``bboxes`` is a list of
        ``[xmin, ymin, xmax, ymax]`` integer lists and ``category_ids`` is the
        list of matching ``<name>`` strings, both in document order. Both
        lists are empty when the file cannot be parsed.
    :raises OSError: if the file cannot be opened.
    """
    with open(xml_path, encoding='utf-8') as in_file:
        try:
            tree = ET.parse(in_file)
        except (ET.ParseError, UnicodeDecodeError):
            # Best effort: an unreadable annotation is treated as "no objects"
            # so the caller can skip the image instead of aborting the run.
            return [], []
    bboxes = []
    category_ids = []
    for obj in tree.getroot().iter('object'):
        cls = obj.find('name').text
        xmlbox = obj.find('bndbox')
        # int(float(...)) also accepts coordinates written as "12.0".
        bboxes.append([int(float(xmlbox.find(tag).text))
                       for tag in ('xmin', 'ymin', 'xmax', 'ymax')])
        category_ids.append(cls)
    return bboxes, category_ids
def _build_transform():
    """Return the augmentation pipeline applied to every generated copy."""
    return A.Compose(
        [
            A.HorizontalFlip(p=0.5),  # horizontal flip
            A.VerticalFlip(p=0.5),  # vertical flip
            # Small random changes of brightness, contrast, saturation and hue
            A.ColorJitter(brightness=0.05, contrast=0.05,
                          saturation=0.02, hue=0.02, p=1),
            A.Sharpen(p=1),  # sharpen to emphasise detail
        ],
        # Boxes are (xmin, ymin, xmax, ymax); class names travel in the
        # 'category_ids' field so they stay aligned with their boxes.
        bbox_params=A.BboxParams(format='pascal_voc',
                                 label_fields=['category_ids']),
    )


def _log_missing_annotation(cur_dir, image_name):
    """Append ``image_name`` to ``cur_dir/no-annotations.txt``."""
    with open(os.path.join(cur_dir, "no-annotations.txt"), 'a') as f:
        print(("No this annotations, name of image : " + image_name), file=f)


def main(cur_dir):
    """Generate augmented copies of every annotated image under ``cur_dir``.

    For each file in ``cur_dir/IMAGES_FILE`` that has a same-named ``.xml`` in
    ``cur_dir/ANNOTATIONS_FILE``, GENERATED_PICS_SIZE augmented copies are
    written. Images go to ``cur_dir/images_aug`` as ``<stem>_001.jpg`` and so
    on, and matching annotations go to ``cur_dir/annotations_aug``.

    Images without an annotation file are listed in
    ``cur_dir/no-annotations.txt``. Images whose annotation is empty or
    unparsable, and files OpenCV cannot decode, are skipped.

    :param cur_dir: dataset root directory.
    """
    images_path = os.path.join(cur_dir, IMAGES_FILE)
    xmls_path = os.path.join(cur_dir, ANNOTATIONS_FILE)
    new_images_path = os.path.join(cur_dir, "images_aug")
    new_xmls_path = os.path.join(cur_dir, "annotations_aug")
    # The output folders may not exist yet; writing into a missing folder
    # would fail on the very first image.
    os.makedirs(new_images_path, exist_ok=True)
    os.makedirs(new_xmls_path, exist_ok=True)
    # The pipeline is identical for every image, so build it once.
    transform = _build_transform()

    for image_name in sorted(os.listdir(images_path)):
        # splitext keeps dotted names such as "a.b.jpg" intact ("a.b").
        stem = os.path.splitext(image_name)[0]
        xml_path = os.path.join(xmls_path, stem + ".xml")
        if not os.path.exists(xml_path):
            _log_missing_annotation(cur_dir, image_name)
            continue

        # Loop-invariant per image: parse the annotation once.
        bboxes, category_ids = getAnnotation(xml_path=xml_path)
        if not bboxes:
            continue  # empty or unparsable annotation: nothing to augment

        image_path = os.path.join(images_path, image_name)
        image = cv2.imread(image_path)
        if image is None:
            print("Cannot read image, skipped: " + image_path)
            continue
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        for i in range(1, GENERATED_PICS_SIZE + 1):
            # Name update: 1.jpg -> 1_001.jpg ... 1_020.jpg (same for .xml)
            suffix = "_" + str(i).zfill(3)
            new_image_name = stem + suffix + ".jpg"
            new_image_path = os.path.join(new_images_path, new_image_name)
            new_xml_path = os.path.join(new_xmls_path, stem + suffix + ".xml")

            # Defensive copy so every iteration starts from the original pixels.
            transformed = transform(image=image.copy(), bboxes=bboxes,
                                    category_ids=category_ids)
            out_image = cv2.cvtColor(transformed['image'], cv2.COLOR_RGB2BGR)
            # presumably imencode + tofile (rather than cv2.imwrite) is used so
            # non-ASCII paths work on Windows -- keep unless verified otherwise
            ok, buffer = cv2.imencode('.jpg', out_image)
            if not ok:
                raise RuntimeError("JPEG encoding failed for " + new_image_path)
            buffer.tofile(new_image_path)
            saveNewAnnotation(new_xml_path, new_image_name, xml_path,
                              transformed['bboxes'], cur_dir)
            print(new_image_name)
if __name__ == "__main__":
    # Script entry point: augment the dataset rooted at DIR.
    main(DIR)