基于Albumentations的图像增强,对yolo数据,生成增强后的图片以及标签。

Albumentations是一个用于图像增强的Python库,它提供了多种增强技术,包括随机裁剪、旋转、缩放、翻转、变形、颜色变换、模糊等操作。使用Albumentations库可以快速、高效地对图像数据进行增强,从而提升机器学习模型的鲁棒性。

本人根据非常棒的Albumentations数据增强库进行二次封装,将yolo数据生成增强后的标签跟图片,代码更改路径可直接调用。

from albumentations import *
import os
import cv2
from tqdm import tqdm
 
 
class enhancement:
    def __init__(self, picture_path, label_path, save_img_path, save_lable_path):
        self.picture_name = sorted(os.listdir(picture_path))
        self.label_name = sorted(os.listdir(label_path))
        self.picture_path = [picture_path + i for i in self.picture_name]
        self.label_path = [label_path + i for i in self.label_name]
        self.save_img_path = save_img_path
        self.save_lable_path = save_lable_path
 
    def iter(self):
        batch_size = 10
        for index_bin in tqdm(range(0, len(self.picture_path), batch_size), desc='批次进度'):
            # print(index_bin)
            picture_batch = self.picture_path[index_bin:index_bin + batch_size]
            label_batch = self.label_path[index_bin:index_bin + batch_size]
            yield picture_batch, label_batch, [index_bin, index_bin + batch_size]
 
    def get_transform(self):
        '''
        这里修改需要图像增强的具体方法
        :return:
        '''
        transform = Compose([
            # 图像均值平滑滤波。
            Blur(blur_limit=7, always_apply=False, p=0.5),
            # VerticalFlip 水平翻转
            VerticalFlip(always_apply=False, p=0.5),
            # HorizontalFlip 垂直翻转
            HorizontalFlip(always_apply=False, p=1),
            # 中心裁剪
            CenterCrop(200, 200, always_apply=False, p=1.0),
            # RandomFog(fog_coef_lower=0.3, fog_coef_upper=0.7, alpha_coef=0.08, always_apply=False, p=1),
            # RandomCrop(width=200, height=200)
            # 添加其他增强技术
        ], bbox_params=BboxParams(format='yolo', label_fields=['class_labels']))
        return transform
 
    def augmentations(self, image, bboxes, class_labels):
        transform = self.get_transform()
        transformed = transform(image=image, bboxes=bboxes, class_labels=class_labels)
        augmented_image = transformed['image']
        augmented_bboxes = transformed['bboxes']
        augmented_labels = transformed['class_labels']
        return augmented_image, augmented_bboxes, augmented_labels
 
    def augmented_image_bboxes(self, img_path, l_path):
        with open(l_path, 'r') as f:
            values = f.read()
            f.close()
        class_labels, original_bboxes = [], []
        values = [i.split(' ') for i in values.split('\n')[:-1]]
        for i in values:
            class_labels.append(int(i[0]))
            original_bboxes.append([float(i) for i in i[1:]])
        original_image = cv2.imread(img_path)
        augmented_image, augmented_bboxes, augmented_labels = self.augmentations(original_image, original_bboxes,
                                                                                 class_labels)
        return augmented_image, augmented_bboxes, augmented_labels, original_image
 
    def parsing_data(self, p_l_i):
        img_path, l_path, index = p_l_i[0], p_l_i[1], p_l_i[2]
        self.augmented_image, self.augmented_bboxes, augmented_labels, original_image = self.augmented_image_bboxes(
            img_path, l_path)
        data = []
        for l, d in zip(augmented_labels, self.augmented_bboxes):
            s = ' '.join(map(str, [l] + list(d)))
            data.append(s)
        data = '\n'.join(data)
        if augmented_labels:
            self.show_img()
            self.save_img_lable(data, self.augmented_image, self.save_img_path, self.save_lable_path, index)
        else:
            print(f'{self.picture_name[index]}该图片没有标签,不做保存')
 
    def save_img_lable(self, data, img, save_img_path, save_lable_path, index):
        cv2.imwrite(save_img_path + self.picture_name[index], img)
        with open(save_lable_path + self.label_name[index], 'w') as f:
            f.write(data)
            f.close()
 
    def __call__(self):
        for picture_batch, label_batch, index_bin in self.iter():
            list(map(self.parsing_data,
                     [(p, l, i) for p, l, i in zip(picture_batch, label_batch, range(index_bin[0], index_bin[1]))]))
 
    def show_img(self, boxe=True):
        '''
        boxe = True,则保存的图片会有标签框
        '''
        if boxe:
            for j in self.augmented_bboxes:
                x, y, w, h = j
                x1 = int((x - w / 2) * self.augmented_image.shape[1])
                y1 = int((y - h / 2) * self.augmented_image.shape[0])
                x2 = int((x + w / 2) * self.augmented_image.shape[1])
                y2 = int((y + h / 2) * self.augmented_image.shape[0])
                cv2.rectangle(self.augmented_image, (x1, y1), (x2, y2), (255, 0, 0), 2)
                cv2.rectangle(self.augmented_image, (x1, y1), (x2, y2), (255, 0, 0), 2)
        else:
            pass
        # cv2.imshow('Augmented Image', self.augmented_image)
        # cv2.waitKey(0)
        # cv2.destroyAllWindows()
 
 
if __name__ == '__main__':
    # 原图片,标签的路径
    picture_path = 'D:\\mydemo\\mask_data\\MaskDataset\\images\\train\\'
    label_path = 'D:\\mydemo\\mask_data\\MaskDataset\\labels\\train\\'
    # 增强后的图片跟标签
    save_img_path = 'D:\\mydemo\\mask_data\\MaskDataset\\images\\dd\\'
    save_lable_path = 'D:\\mydemo\mask_data\\MaskDataset\\images\\dd\\'
    c = enhancement(picture_path=picture_path,
                    label_path=label_path,
                    save_img_path=save_img_path,
                    save_lable_path=save_lable_path)
    c()

效果如下:

增强前后的标签以及图片变化:
在这里插入图片描述在这里插入图片描述
在这里插入图片描述

原文链接基于Albumentations的图像增强,对yolo数据,生成增强后的图片以及标签。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
如果你已经有了原始的标注数据,那么你可以直接使用 Albumentations 库提供的 BboxParams 将标注数据转换为适用于 YOLO 模型的格式。下面是一个示例代码: ``` import albumentations as A from albumentations.augmentations.bbox_utils import convert_bbox_to_albumentations # 假设原始标注数据格式为 [x_min, y_min, x_max, y_max, class_id] bboxes = [[100, 100, 200, 200, 0], [300, 300, 400, 400, 1], ...] # 将 bboxes 转换为适用于 Albumentations 的格式 # 注意,这里的标注数据是针对原始图像的,还没有进行裁剪、缩放等操作 transformed_bboxes = [convert_bbox_to_albumentations(bbox, 'pascal_voc', (height, width)) for bbox in bboxes] # 定义 transform 对象,包括图像增强方法和标注数据转换方法 transform = A.Compose([ # 图像增强方法... ], bbox_params=A.BboxParams(format='yolo', label_fields=['category_ids', 'bboxes'])) # 将 transform 应用于数据集 dataset = YourDataset(...) dataset.transforms = transform dataset.bboxes = transformed_bboxes ``` 在上面的代码中,我们首先将原始标注数据转换为 Albumentations 支持的格式,然后定义了一个 transform 对象,其中 bbox_params 参数指定了标注数据的格式为 YOLO 格式,同时也指定了标注数据中类别 id 和边界框坐标的字段名称。最后将 transform 应用于数据集,并将转换后的标注数据赋值给数据集的 bboxes 属性。注意,由于 Albumentations 库对图像进行了裁剪、缩放等操作,因此需要在标注数据转换函数中将标注数据也进行相应的裁剪、缩放等操作。具体的实现方法可以参考 Albumentations 库的文档。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值