任务:识别自定义数据集中的目标物品(这些类别不在 YOLOv5 预训练模型可识别的类别中,需要自己训练模型)。训练部分的代码已有很多教程,本文不再赘述。
本文着重介绍:在自定义数据很少但又需要训练模型的情况下,如何进行数据增强,并得到增强后每张图片对应的 txt 标注文件。
解决方案的代码分三步:
1.使用Albumentations库做图像增强并生成相应的xml标注文件
2.标注文件xml转txt
3.将图片和标签划分成可用于训练的train,test,val
(三部分代码可放在一个文件中一次性执行)
先导入下列代码需要的所有包
import random
import time
import os
import cv2
import shutil
import numpy as np
import albumentations as A
import xml.etree.ElementTree as ET
1.使用Albumentations库做图像增强并生成相应的xml标注文件
此处代码参考了下列文章,并根据个人需求做了一些修改。原文中对每张已有图片只生成一张增强图片,且每次增强可以叠加多个策略。
yolo数据增强以及批量修改图片和xml名_yolo的xml文件名要求-CSDN博客
而本文针对少量数据,为每张原图生成多张增强后的图像,每张原图生成的张数可在代码中通过 num_augmentations 设置;并且每次增强只随机选择一种策略,这一点可根据个人需要修改。
代码中需要自行设置的部分为文件路径(绝对路径)和标签列表 LABELS,设置完之后可以直接运行。
# 1. Image augmentation that also produces the matching Pascal VOC xml annotations.
class VOCAug(object):
    """Generate augmented copies of Pascal VOC annotated images.

    For every ``*.xml`` file in ``pre_xml_path`` the referenced image is loaded,
    augmented ``num_augmentations`` times (one randomly chosen Albumentations
    transform per copy), and each copy is written to ``aug_image_save_path``
    together with an updated xml file in ``aug_xml_save_path``.
    """

    def __init__(self,
                 pre_image_path=None,
                 pre_xml_path=None,
                 aug_image_save_path=None,
                 aug_xml_save_path=None,
                 start_aug_id=None,
                 labels=None,
                 max_len=3,  # name zero-padding width: 1-1, 2-01, 3-001, 4-0001 (stored only, currently unused)
                 is_show=False):
        """
        :param pre_image_path: directory containing the original images
        :param pre_xml_path: directory containing the original VOC xml files
        :param aug_image_save_path: output directory for augmented images
        :param aug_xml_save_path: output directory for augmented xml files
        :param start_aug_id: first value of the running counter embedded in the
            output file names; defaults to ``len(os.listdir(pre_xml_path)) + 1``
        :param labels: list of class names; a class's index in this list is its id
        :param max_len: stored only; not used by the current implementation
        :param is_show: stored only; not used by the current implementation
        :raises ValueError: if ``labels`` is None
        """
        self.pre_image_path = pre_image_path
        self.pre_xml_path = pre_xml_path
        self.aug_image_save_path = aug_image_save_path
        self.aug_xml_save_path = aug_xml_save_path
        self.start_aug_id = start_aug_id
        self.labels = labels
        self.max_len = max_len
        self.is_show = is_show
        # Explicit check instead of assert: asserts are stripped under `python -O`.
        if self.labels is None:
            raise ValueError("labels is None!!!")
        print('--------------*--------------')
        print("labels: ", self.labels)
        if self.start_aug_id is None:
            self.start_aug_id = len(os.listdir(self.pre_xml_path)) + 1
            print("the start_aug_id is not set, default: len(images)", self.start_aug_id)
        print('--------------*--------------')

    @staticmethod
    def _is_difficult(obj):
        """Return True if a VOC <object> is marked <difficult>1</difficult>.

        A missing or empty <difficult> tag is treated as 0 (not difficult).
        """
        difficult = obj.find('difficult')
        return difficult is not None and bool(difficult.text) and int(difficult.text) == 1

    @staticmethod
    def _make_object(name, xmin, ymin, xmax, ymax):
        """Build a VOC <object> element for one class name and pixel box."""
        obj = ET.Element('object')
        ET.SubElement(obj, 'name').text = name
        ET.SubElement(obj, 'pose').text = 'Unspecified'
        ET.SubElement(obj, 'truncated').text = '0'
        # Difficult objects were filtered out in get_xml_data, so '0' is accurate.
        # The tag is written because the xml->txt step reads it.
        ET.SubElement(obj, 'difficult').text = '0'
        bndbox = ET.SubElement(obj, 'bndbox')
        for tag, value in (('xmin', xmin), ('ymin', ymin), ('xmax', xmax), ('ymax', ymax)):
            ET.SubElement(bndbox, tag).text = str(int(value))
        return obj

    def get_xml_data(self, xml_filename):
        """Parse one VOC xml file and load the image it references.

        Objects whose class is not in ``self.labels``, objects marked difficult,
        and boxes that are empty after clamping are skipped.

        :param xml_filename: file name (not a path) inside ``pre_xml_path``
        :return: ``(bboxes, cls_id_list, image, image_name)``. ``bboxes`` holds
            ``[xmin, ymin, xmax, ymax]`` pixel boxes clamped to the size declared
            in the xml. ``image`` is the ``cv2.imread`` result (None when the
            file could not be read).
        """
        # Passing a path lets ElementTree read bytes and honour the encoding in
        # the xml header; a text-mode open() would use the platform default encoding.
        tree = ET.parse(os.path.join(self.pre_xml_path, xml_filename))
        root = tree.getroot()
        image_name = root.find('filename').text
        size = root.find('size')
        w = int(size.find('width').text)
        h = int(size.find('height').text)
        bboxes = []
        cls_id_list = []
        for obj in root.iter('object'):
            cls_name = obj.find('name').text  # class label
            # Filter against this instance's labels, not a module-level global.
            if cls_name not in self.labels or self._is_difficult(obj):
                continue
            xml_box = obj.find('bndbox')
            # float() first: some annotation tools write coordinates like "12.0".
            xmin, ymin, xmax, ymax = (int(float(xml_box.find(tag).text))
                                      for tag in ('xmin', 'ymin', 'xmax', 'ymax'))
            # Clamp out-of-bounds annotations to the image declared in the xml.
            xmin, ymin = max(xmin, 0), max(ymin, 0)
            xmax, ymax = min(xmax, w), min(ymax, h)
            if xmax <= xmin or ymax <= ymin:
                # A zero-area pascal_voc box would make Albumentations' bbox check raise.
                print(f"skip degenerate box {[xmin, ymin, xmax, ymax]} in {xml_filename}")
                continue
            bboxes.append([xmin, ymin, xmax, ymax])
            cls_id_list.append(self.labels.index(cls_name))
        image = cv2.imread(os.path.join(self.pre_image_path, image_name))
        return bboxes, cls_id_list, image, image_name

    def aug_image(self, num_augmentations=1000):
        """Create ``num_augmentations`` augmented copies of every annotated image.

        Each copy uses exactly one transform picked at random from the candidate
        list below.

        :param num_augmentations: number of augmented copies per source image
        """
        for save_dir in (self.aug_image_save_path, self.aug_xml_save_path):
            if save_dir:
                # cv2.imwrite fails silently (returns False) when the directory is missing.
                os.makedirs(save_dir, exist_ok=True)

        candidate_transforms = [
            A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, p=1),
            A.RandomRotate90(p=1),
            A.GaussianBlur(p=1),  # Gaussian blur
            # NOTE(review): newer Albumentations releases changed GaussNoise's
            # arguments; adjust var_limit/mean if your installed version rejects them.
            A.GaussNoise(var_limit=(400, 450), mean=0, p=1),  # Gaussian noise
            A.Rotate(limit=45, interpolation=0, border_mode=0, p=1),
            A.Rotate(limit=30, interpolation=0, border_mode=0, p=1),
            A.Rotate(limit=75, interpolation=0, border_mode=0, p=1),
            A.Rotate(limit=120, interpolation=0, border_mode=0, p=1),
            A.RGBShift(r_shift_limit=50, g_shift_limit=50, b_shift_limit=50, p=1),
            A.ColorJitter(p=1),  # random brightness / contrast / saturation / hue
            A.Downscale(p=1),  # downscale then upscale to lower image quality
            # Optional extras (crop size must not exceed the image size):
            # A.RandomCrop(width=256, height=256, p=1.0),
            # A.Emboss(p=0.2),  # emboss the image and blend with the original
            # A.CLAHE(clip_limit=2.0, tile_grid_size=(4, 4), p=0.8),  # histogram equalization
            # A.Equalize(p=0.8),  # equalize the image histogram
            # A.ChannelShuffle(p=0.3),  # randomly permute channels
        ]
        # One single-transform pipeline per candidate, built once rather than per copy.
        pipelines = [
            A.Compose([transform],
                      bbox_params=A.BboxParams(format='pascal_voc', min_area=0., min_visibility=0.,
                                               label_fields=['category_id']))
            for transform in candidate_transforms
        ]

        cnt = self.start_aug_id
        for xml in os.listdir(self.pre_xml_path):
            if not xml.endswith('.xml'):
                continue
            bboxes, cls_id_list, image, image_name = self.get_xml_data(xml)
            if image is None:
                print(f"cannot read image {image_name!r} referenced by {xml!r}, skipped")
                continue
            for _ in range(num_augmentations):
                # Pick one augmentation strategy at random for this copy.
                self.aug = random.choice(pipelines)
                augmented = self.aug(image=image, bboxes=bboxes, category_id=cls_id_list)
                if self.save_aug_data(augmented, image_name, cnt, xml_filename=xml):
                    cnt += 1
                else:
                    break  # stop augmenting this image if saving failed

    def save_aug_data(self, augmented, image_name, cnt, xml_filename=None):
        """Write one augmented image and its VOC xml annotation.

        The xml is derived from the source xml. Its <filename> and <size> are
        replaced with the augmented values, and the <object> list is rebuilt
        from ``augmented['bboxes']`` / ``augmented['category_id']``. Editing the
        source objects by position (the previous approach) raised
        ``IndexError: list index out of range``. That happens whenever the
        transform drops boxes (e.g. a crop), or when ``get_xml_data`` skipped
        difficult or unknown objects.

        :param augmented: Albumentations result with 'image', 'bboxes', 'category_id'
        :param image_name: file name of the source image
        :param cnt: running counter embedded in the output file names
        :param xml_filename: source xml name inside ``pre_xml_path``; defaults to
            ``image_name`` with its extension replaced by ``.xml``
        :return: True on success, False if the image could not be written
        """
        aug_image = augmented['image']
        stem, ext = os.path.splitext(image_name)
        # Counter + timestamp keeps names unique across repeated runs.
        timestamp = int(time.time())
        new_stem = f"{stem}_{cnt}_{timestamp}"
        new_image_name = new_stem + ext
        new_xml_name = new_stem + '.xml'

        if not cv2.imwrite(os.path.join(self.aug_image_save_path, new_image_name), aug_image):
            return False

        if xml_filename is None:
            xml_filename = stem + '.xml'
        aug_tree = ET.parse(os.path.join(self.pre_xml_path, xml_filename))
        root = aug_tree.getroot()
        root.find('filename').text = new_image_name

        # Rotations by 90 degrees and crops change the image size; keep <size> in
        # sync so the xml->txt normalization stays correct.
        height, width = aug_image.shape[:2]
        size = root.find('size')
        size.find('width').text = str(width)
        size.find('height').text = str(height)

        # Replace the source objects with one object per surviving augmented box.
        for obj in list(root.findall('object')):
            root.remove(obj)
        for (xmin, ymin, xmax, ymax), cat_id in zip(augmented['bboxes'], augmented['category_id']):
            root.append(self._make_object(self.labels[int(cat_id)], xmin, ymin, xmax, ymax))

        aug_tree.write(os.path.join(self.aug_xml_save_path, new_xml_name), encoding='utf-8')
        return True
# Directories of the original images and xml files
PRE_IMAGE_PATH = '/...'
PRE_XML_PATH = '/...'
# Directories where the augmented images and xml files are saved
AUG_SAVE_IMAGE_PATH = '/...'
AUG_SAVE_XML_PATH = '/...'
# Class names; a class's index in this list is used as its class id
LABELS = ['标签名']
# Number of augmented copies generated for each original image
NUM_AUGMENTATIONS = 1000

# Guarded like part 2 so importing this file does not start augmenting.
if __name__ == "__main__":
    aug = VOCAug(
        pre_image_path=PRE_IMAGE_PATH,
        pre_xml_path=PRE_XML_PATH,
        aug_image_save_path=AUG_SAVE_IMAGE_PATH,
        aug_xml_save_path=AUG_SAVE_XML_PATH,
        start_aug_id=None,
        labels=LABELS,
        is_show=False,
    )
    aug.aug_image(num_augmentations=NUM_AUGMENTATIONS)
2.标注文件xml转txt
代码中需要自行设置的部分为文件路径(绝对路径)以及类别列表 classes1。classes1 必须与第一部分的 LABELS 完全一致(包括顺序),否则目标会被跳过或得到错误的类别编号。
# 2. Convert the xml annotation files to YOLO txt files.
def convert(size, box):
    """Convert a pixel box to YOLO's normalized (x_center, y_center, width, height).

    :param size: ``(image_width, image_height)`` in pixels
    :param box: ``(xmin, xmax, ymin, ymax)`` in pixels -- note the x, x, y, y order
    :return: tuple of four floats, each divided by the matching image dimension
    """
    img_w, img_h = size[0], size[1]
    x_lo, x_hi, y_lo, y_hi = box[0], box[1], box[2], box[3]
    center_x = (x_lo + x_hi) / 2.0
    center_y = (y_lo + y_hi) / 2.0
    return (
        center_x / img_w,
        center_y / img_h,
        (x_hi - x_lo) / img_w,
        (y_hi - y_lo) / img_h,
    )
def convert_annotation(xml_files_path, save_txt_files_path, classes):
    """Convert every VOC xml file in a directory into a YOLO txt label file.

    Each output line is ``class_id x_center y_center width height``, with the
    coordinates normalized by the image size stored in the xml (see ``convert``).
    Objects whose class is not in ``classes``, or that are marked difficult,
    are skipped. Files without a ``.xml`` extension are ignored.

    :param xml_files_path: directory containing the VOC xml files
    :param save_txt_files_path: output directory (created if missing)
    :param classes: class names; the index in this list is the class id written
    """
    os.makedirs(save_txt_files_path, exist_ok=True)
    for xml_name in os.listdir(xml_files_path):
        # splitext keeps dotted stems intact ("a.b.xml" -> "a.b"), avoiding name collisions.
        stem, ext = os.path.splitext(xml_name)
        if ext.lower() != '.xml':
            continue
        print(xml_name)
        root = ET.parse(os.path.join(xml_files_path, xml_name)).getroot()
        size = root.find('size')
        w = int(size.find('width').text)
        h = int(size.find('height').text)
        lines = []
        for obj in root.iter('object'):
            difficult = obj.find('difficult')
            # A missing/empty <difficult> tag counts as "not difficult".
            is_difficult = difficult is not None and bool(difficult.text) and int(difficult.text) == 1
            cls = obj.find('name').text
            if cls not in classes or is_difficult:
                continue
            cls_id = classes.index(cls)
            xmlbox = obj.find('bndbox')
            # b = (xmin, xmax, ymin, ymax) -- the order convert() expects
            b = (float(xmlbox.find('xmin').text), float(xmlbox.find('xmax').text),
                 float(xmlbox.find('ymin').text), float(xmlbox.find('ymax').text))
            bb = convert((w, h), b)
            lines.append(str(cls_id) + " " + " ".join(str(a) for a in bb) + '\n')
        # The with-block guarantees the file is flushed and closed.
        with open(os.path.join(save_txt_files_path, stem + '.txt'), 'w') as out_txt_f:
            out_txt_f.writelines(lines)
if __name__ == "__main__":
    # Classes to convert. This must be the same list, in the same order, as the
    # labels used for augmentation (LABELS). A different list makes
    # convert_annotation skip every object or assign the wrong class ids.
    classes1 = LABELS
    # Directory of the VOC-format xml label files
    xml_files1 = AUG_SAVE_XML_PATH
    # Output directory for the YOLO-format txt label files
    save_txt_files1 = '/...'
    convert_annotation(xml_files1, save_txt_files1, classes1)
    # Augmented images are now in AUG_SAVE_IMAGE_PATH = '/...'
    # and their txt labels in save_txt_files1 = '/...'
3.将图片和标签划分成可用于训练的train,test,val
代码中需要自行设置的部分为文件路径(绝对路径)
# 3. Split the images and labels into train / val / test sets usable for training.
val_size = 0.1
test_size = 0.1
# Preferred image extension. Other common extensions are tried as a fallback,
# because augmented images keep the extension of their source image.
postfix = 'jpg'
# Directory that holds the image files
imgpath = '/...'
# Directory that holds the txt label files
txtpath = '/...'
# Root of the output layout images/<split> and labels/<split> ('.' = current working directory)
output_root = '.'

_FALLBACK_IMAGE_EXTENSIONS = ('jpg', 'jpeg', 'png', 'bmp')


def _find_image(stem):
    """Return the path of the image named ``stem`` in ``imgpath``, or None.

    ``postfix`` is tried first, then the remaining fallback extensions.
    """
    extensions = [postfix] + [ext for ext in _FALLBACK_IMAGE_EXTENSIONS if ext != postfix]
    for ext in extensions:
        candidate = os.path.join(imgpath, f'{stem}.{ext}')
        if os.path.isfile(candidate):
            return candidate
    return None


def _copy_split(split_name, label_files):
    """Copy each label file of one split, plus its image, into the output layout.

    Creates ``images/<split_name>`` and ``labels/<split_name>`` under
    ``output_root``. Label files without a matching image are reported and skipped.
    """
    image_dir = os.path.join(output_root, 'images', split_name)
    label_dir = os.path.join(output_root, 'labels', split_name)
    os.makedirs(image_dir, exist_ok=True)
    os.makedirs(label_dir, exist_ok=True)
    for label_file in label_files:
        stem = os.path.splitext(label_file)[0]
        image_src = _find_image(stem)
        if image_src is None:
            print(f'no image found for {label_file}, skipped')
            continue
        shutil.copy(image_src, os.path.join(image_dir, os.path.basename(image_src)))
        shutil.copy(os.path.join(txtpath, label_file), os.path.join(label_dir, label_file))


# Guarded like part 2 so importing this file does not copy anything.
if __name__ == '__main__':
    # endswith: a substring test ('txt' in name) would also match e.g. 'txt_notes.jpg'.
    listdir = np.array([f for f in os.listdir(txtpath) if f.endswith('.txt')])
    np.random.shuffle(listdir)
    n_files = len(listdir)
    train_end = int(n_files * (1 - val_size - test_size))
    val_end = int(n_files * (1 - test_size))
    train, val, test = listdir[:train_end], listdir[train_end:val_end], listdir[val_end:]
    print(f'train set size:{len(train)} val set size:{len(val)} test set size:{len(test)}')
    for split_name, split_files in (('train', train), ('val', val), ('test', test)):
        _copy_split(split_name, split_files)
以上就是实现图像增强并获得相应标注文件的代码
楼主在运行时遇到一个问题:
A.RandomCrop(width=256, height=256,p=1.0)
在选择随机裁剪作为增强策略时,代码出现报错:
Traceback (most recent call last):
File "draft.py", line 181, in <module>
aug.aug_image()
File "draft.py", line 112, in aug_image
flag = self.save_aug_data(augmented, image_name, cnt)
File "draft.py", line 145, in save_aug_data
obj.find('name').text = self.labels[aug_category_id[index]]
IndexError: list index out of range
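楼主最初认为,这是随机裁剪后图像中目标的标注可能变为 0 个导致的错误。进一步看代码可以发现根本原因:save_aug_data 按原始 xml 中所有 <object> 节点的顺序,用下标去取增强后的 aug_category_id / aug_bboxes。而随机裁剪会丢弃完全落在裁剪区域之外的框;get_xml_data 也会跳过 difficult=1 或不在标签列表中的目标。于是增强后的列表比 <object> 节点少,按下标访问时就触发了 IndexError。

解决思路:不要按下标修改原有的 <object> 节点,而是先把它们全部删除,再根据 augmented['bboxes'] 和 augmented['category_id'] 重新生成 <object> 节点。同时要更新 <size>,因为裁剪和 RandomRotate90 都可能改变图像尺寸。

如果有遇到相同问题或是已经解决的朋友,欢迎在评论区交流~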