yolov5训练自定义数据集使用Albumentations库做数据增强,得到增强后图片与txt标注文件

任务:识别自定义数据集中的目标物品(不包含在yolov5预训练模型的可识别类型中,需要自己训练模型),训练部分的代码已有很多教程,本文不再重复。

本文着重写针对自定义数据很少但又需要训练的情况,如何实现数据增强及获得增强后图片相应的txt标注文件。

解决方案的代码分三步:

1.使用Albumentations库做图像增强并生成相应的xml标注文件

2.标注文件xml转txt

3.将图片和标签划分成可用于训练的train,test,val

(三部分代码可放在一个文件中一次性执行)

先导入下列代码需要的所有包

import random
import time
import os
import cv2
import shutil
import numpy as np
import albumentations as A
import xml.etree.ElementTree as ET
1.使用Albumentations库做图像增强并生成相应的xml标注文件

此处代码参考了下列文章,并根据个人需求做了一些修改,文章中是对每张已有的图只生成一张增强图片,且增强策略是可选择多个策略叠加。

yolo数据增强以及批量修改图片和xml名_yolo的xml文件名要求-CSDN博客

而本文实现的是针对少量数据生成多张增强后的图像,可在代码中设置想要增强的张数(num_augmentations),且增强策略是每次增强只选择一项,此处可根据个人需要修改

代码中需要自行设置的部分为文件路径(绝对路径),设置完之后可以直接运行

#1.图像增强并生成相应的xml标注文件
class VOCAug(object):

    def __init__(self,
                 pre_image_path=None,
                 pre_xml_path=None,
                 aug_image_save_path=None,
                 aug_xml_save_path=None,
                 start_aug_id=None,
                 labels=None,
                 max_len=3,  # 修改数值可以改变名字 1-1, 2-01, 3-001, 4-0001
                 is_show=False):
        """
        :param pre_image_path:
        :param pre_xml_path:
        :param aug_image_save_path:
        :param aug_xml_save_path:
        :param start_aug_id:
        :param labels: 标签列表, 展示增强后的图片用
        :param max_len:
        :param is_show:
        """
        self.pre_image_path = pre_image_path
        self.pre_xml_path = pre_xml_path
        self.aug_image_save_path = aug_image_save_path
        self.aug_xml_save_path = aug_xml_save_path
        self.start_aug_id = start_aug_id
        self.labels = labels
        self.max_len = max_len
        self.is_show = is_show

        print(self.labels)
        assert self.labels is not None, "labels is None!!!"


        print('--------------*--------------')
        print("labels: ", self.labels)
        if self.start_aug_id is None:
            self.start_aug_id = len(os.listdir(self.pre_xml_path)) + 1
            print("the start_aug_id is not set, default: len(images)", self.start_aug_id)
        print('--------------*--------------')

    def get_xml_data(self, xml_filename):
        with open(os.path.join(self.pre_xml_path, xml_filename), 'r') as f:
            tree = ET.parse(f)
            root = tree.getroot()
            image_name = tree.find('filename').text
            size = root.find('size')
            w = int(size.find('width').text)
            h = int(size.find('height').text)
            bboxes = []
            cls_id_list = []
            for obj in root.iter('object'):
                # difficult = obj.find('difficult').text
                difficult = obj.find('difficult').text
                cls_name = obj.find('name').text  # label
                if cls_name not in LABELS or int(difficult) == 1:
                    continue
                xml_box = obj.find('bndbox')

                xmin = int(xml_box.find('xmin').text)
                ymin = int(xml_box.find('ymin').text)
                xmax = int(xml_box.find('xmax').text)
                ymax = int(xml_box.find('ymax').text)

                # 标注越界修正
                if xmax > w:
                    xmax = w
                if ymax > h:
                    ymax = h
                bbox = [xmin, ymin, xmax, ymax]
                bboxes.append(bbox)
                cls_id_list.append(self.labels.index(cls_name))

            # 读取图片
            image = cv2.imread(os.path.join(self.pre_image_path, image_name))

        return bboxes, cls_id_list, image, image_name

    def aug_image(self, num_augmentations=1000):
        xml_list = os.listdir(self.pre_xml_path)

        cnt = self.start_aug_id
        for xml in xml_list:
            if xml.split('.')[-1] != 'xml':
                continue
            bboxes, cls_id_list, image, image_name = self.get_xml_data(xml)
            for _ in range(num_augmentations):
                # 每次循环随机选择一组增强方法
                random_augmentations = [
                    A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, p=1),
                    A.RandomRotate90(p=1),
                    A.GaussianBlur(p=1), # 高斯模糊
                    A.GaussNoise(var_limit=(400, 450),mean=0,p=1),  # 高斯噪声
                    A.Rotate(limit=45, interpolation=0, border_mode=0, p=1),
                    A.Rotate(limit=30, interpolation=0, border_mode=0, p=1),
                    A.Rotate(limit=75, interpolation=0, border_mode=0, p=1),
                    A.Rotate(limit=120, interpolation=0, border_mode=0, p=1),
                    A.RGBShift(r_shift_limit=50, g_shift_limit=50, b_shift_limit=50, p=1),
                    A.ColorJitter(p=1),  # 随机改变图像的亮度、对比度、饱和度、色调
                    A.Downscale(p=1),  # 随机缩小和放大来降低图像质量
                    #A.RandomCrop(width=256, height=256,p=1.0),
                    # A.Emboss(p=0.2),  # 压印输入图像并将结果与原始图像叠加
                    # A.CLAHE(clip_limit=2.0, tile_grid_size=(4, 4), p=0.8),  # 直方图均衡
                    # A.Equalize(p=0.8),  # 均衡图像直方图
                    # A.ChannelShuffle(p=0.3),# 随机排列通道
                    # ... 其他增强方法 ...
                ]
                # 随机选择增强方法
                selected_augmentations = random.sample(random_augmentations, k=1)#k=min(len(random_augmentations), num_augmentations)
                # 创建增强策略
                self.aug = A.Compose(selected_augmentations, bbox_params=A.BboxParams(format='pascal_voc', min_area=0., min_visibility=0., label_fields=['category_id']))

                anno_dict = {'image': image, 'bboxes': bboxes, 'category_id': cls_id_list}
                augmented = self.aug(**anno_dict)

                # 保存增强后的数据
                flag = self.save_aug_data(augmented, image_name, cnt)

                if flag:
                    cnt += 1
                else:
                    break  # 如果保存失败,则跳出循环

    def save_aug_data(self, augmented, image_name, cnt):
        aug_image = augmented['image']
        aug_bboxes = augmented['bboxes']
        aug_category_id = augmented['category_id']

        # 使用时间戳和增强次数生成唯一的文件名
        timestamp = int(time.time())
        new_image_name = f"{image_name.split('.')[0]}_{cnt}_{timestamp}.{image_name.split('.')[1]}"
        new_xml_name = new_image_name.replace('.' + image_name.split('.')[1], '.xml')

        # 保存增强后的图片
        cv2.imwrite(os.path.join(self.aug_image_save_path, new_image_name), aug_image)

        # 构建对应的XML文件名
        # 假设原始图像文件名和XML文件名具有相同的基本名称
        original_xml_name = image_name.replace('.' + image_name.split('.')[1], '.xml')
        full_path = os.path.join(self.pre_xml_path, original_xml_name)
        with open(full_path, 'r') as pre_xml:
            aug_tree = ET.parse(pre_xml)

        # 修改image_filename值
        root = aug_tree.getroot()
        aug_tree.find('filename').text = new_image_name

        # 修改每一个标注框
        for index, obj in enumerate(root.iter('object')):
            #print("The length of aug_category_id list is:",len(aug_category_id))
            obj.find('name').text = self.labels[aug_category_id[index]]
            xmin, ymin, xmax, ymax = aug_bboxes[index]
            xml_box = obj.find('bndbox')
            xml_box.find('xmin').text = str(int(xmin))
            xml_box.find('ymin').text = str(int(ymin))
            xml_box.find('xmax').text = str(int(xmax))
            xml_box.find('ymax').text = str(int(ymax))

        # 保存增强后的xml文件
        tree = ET.ElementTree(root)
        tree.write(os.path.join(self.aug_xml_save_path, new_xml_name))

        return True


# 原始的xml路径和图片路径
PRE_IMAGE_PATH = '/...'
PRE_XML_PATH = '/...'

# 增强后保存的xml路径和图片路径
AUG_SAVE_IMAGE_PATH = '/...'
AUG_SAVE_XML_PATH = '/...'

# 标签列表
LABELS = ['标签名']

aug = VOCAug(
    pre_image_path=PRE_IMAGE_PATH,
    pre_xml_path=PRE_XML_PATH,
    aug_image_save_path=AUG_SAVE_IMAGE_PATH,
    aug_xml_save_path=AUG_SAVE_XML_PATH,
    start_aug_id=None,
    labels=LABELS,
    is_show=False,
)

aug.aug_image()
2.标注文件xml转txt

代码中需要自行设置的部分为文件路径(绝对路径)

#2.标注文件xml转txt
def convert(size, box):
    x_center = (box[0] + box[1]) / 2.0
    y_center = (box[2] + box[3]) / 2.0
    x = x_center / size[0]
    y = y_center / size[1]
    w = (box[1] - box[0]) / size[0]
    h = (box[3] - box[2]) / size[1]
    return (x, y, w, h)

def convert_annotation(xml_files_path, save_txt_files_path, classes):
    xml_files = os.listdir(xml_files_path)
    print(xml_files)
    for xml_name in xml_files:
        print(xml_name)
        xml_file = os.path.join(xml_files_path, xml_name)
        out_txt_path = os.path.join(save_txt_files_path, xml_name.split('.')[0] + '.txt')
        out_txt_f = open(out_txt_path, 'w')
        tree = ET.parse(xml_file)
        root = tree.getroot()
        size = root.find('size')
        w = int(size.find('width').text)
        h = int(size.find('height').text)

        for obj in root.iter('object'):
            difficult = obj.find('difficult').text
            cls = obj.find('name').text
            if cls not in classes or int(difficult) == 1:
                continue
            cls_id = classes.index(cls)
            xmlbox = obj.find('bndbox')
            b = (float(xmlbox.find('xmin').text), float(xmlbox.find('xmax').text), float(xmlbox.find('ymin').text),
                 float(xmlbox.find('ymax').text))
            # b=(xmin, xmax, ymin, ymax)
            print(w, h, b)
            bb = convert((w, h), b)
            out_txt_f.write(str(cls_id) + " " + " ".join([str(a) for a in bb]) + '\n')


if __name__ == "__main__":
    # 需要转换的类别,需要一一对应
    classes1 = ['标签']
    # 2、voc格式的xml标签文件路径
    xml_files1 = AUG_SAVE_XML_PATH
    # 3、转化为yolo格式的txt标签文件存储路径
    save_txt_files1 = '/...'
    convert_annotation(xml_files1, save_txt_files1, classes1)

#现在得到增强后的图片路径:AUG_SAVE_IMAGE_PATH = '/...'
#和增强后图片的txt标签路径:save_txt_files1 = '/...'
3.将图片和标签划分成可用于训练的train,test,val

代码中需要自行设置的部分为文件路径(绝对路径)

#3.将图片和标签划分成可用于训练的train,test,val
val_size = 0.1
test_size = 0.1
postfix = 'jpg'
#存放图像文件的路径
imgpath = '/...'
#存放txt标注文件的路径
txtpath = '/...'

os.makedirs('images/train', exist_ok=True)
os.makedirs('images/val', exist_ok=True)
os.makedirs('images/test', exist_ok=True)
os.makedirs('labels/train', exist_ok=True)
os.makedirs('labels/val', exist_ok=True)
os.makedirs('labels/test', exist_ok=True)

listdir = np.array([i for i in os.listdir(txtpath) if 'txt' in i])
np.random.shuffle(listdir)
train, val, test = listdir[:int(len(listdir) * (1 - val_size - test_size))], listdir[int(len(listdir) * (
            1 - val_size - test_size)):int(len(listdir) * (1 - test_size))], listdir[int(len(listdir) * (1 - test_size)):]
print(f'train set size:{len(train)} val set size:{len(val)} test set size:{len(test)}')

for i in train:
    shutil.copy('{}/{}.{}'.format(imgpath, i[:-4], postfix), 'images/train/{}.{}'.format(i[:-4], postfix))
    shutil.copy('{}/{}'.format(txtpath, i), 'labels/train/{}'.format(i))

for i in val:
    shutil.copy('{}/{}.{}'.format(imgpath, i[:-4], postfix), 'images/val/{}.{}'.format(i[:-4], postfix))
    shutil.copy('{}/{}'.format(txtpath, i), 'labels/val/{}'.format(i))

for i in test:
    shutil.copy('{}/{}.{}'.format(imgpath, i[:-4], postfix), 'images/test/{}.{}'.format(i[:-4], postfix))
    shutil.copy('{}/{}'.format(txtpath, i), 'labels/test/{}'.format(i))

以上就是实现图像增强并获得相应标注文件的代码

楼主在运行时遇到一个问题:

A.RandomCrop(width=256, height=256,p=1.0)

 在选择随机裁剪作为增强策略时,代码出现报错:

Traceback (most recent call last):
  File "draft.py", line 181, in <module>
    aug.aug_image()
  File "draft.py", line 112, in aug_image
    flag = self.save_aug_data(augmented, image_name, cnt)
  File "draft.py", line 145, in save_aug_data
    obj.find('name').text = self.labels[aug_category_id[index]]
IndexError: list index out of range

楼主认为是随机裁剪使得图像中目标的标注可能为0,导致的错误,如果有遇到相同问题或是已经解决的朋友欢迎在评论区交流~

  • 7
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值