CV计算机视觉的数据增强方法:使用albumentations库批量数据增强,同时实现了坐标变换

Albumentations 是一个流行的 Python 库,专为图像增强任务设计,以帮助提高机器学习和深度学习模型的性能。该库高效且易于使用,支持广泛的增强技术,旨在提供快速且多样化的数据增强方法。

安装 Albumentations

要安装 Albumentations,您可以使用 pip,这是 Python 的包管理工具。在命令行中运行以下命令即可安装:

bash

pip install albumentations

确保您的 Python 环境已激活(如果使用虚拟环境),这样安装的库才会放在正确的环境中。

数据增强方法

Albumentations 提供了多种数据增强方法,包括但不限于:

  1. 几何变换:如旋转(Rotate)、翻转(Flip)、缩放(Scale)、裁剪(Crop)等,这些操作可以改变图像的空间结构。
  2. 颜色变换:如调整亮度(Brightness)、对比度(Contrast)、饱和度(Saturation),以及更复杂的操作如随机对比度(RandomBrightnessContrast)和色调变换(HueSaturationValue)。
  3. 噪声注入:如高斯噪声(GaussianNoise)、盐和胡椒噪声(SaltAndPepper)等,这些增强可以帮助模型学习在噪声存在的情况下进行更鲁棒的预测。
  4. 遮挡和遮蔽:如遮挡部分图像区域(CoarseDropout),用于模拟丢失像素的场景。
  5. 模糊和锐化:如高斯模糊(GaussianBlur)、运动模糊(MotionBlur)和锐化(Sharpen),用于模拟摄影中的常见效果。

使用 Albumentations 进行数据增强不仅可以扩展训练数据集,还可以帮助提高模型对新、未见过的图像数据的泛化能力。它的设计允许轻松集成到现有的数据处理流程中,并支持与其他流行的深度学习库如 PyTorch 和 TensorFlow 的无缝配合。

示例代码

下面是一个使用 Albumentations 进行图像增强的简单示例,同时实现了坐标变换。(可以使用标签可视化工具脚本检查,见我的另外一个帖子:https://mp.csdn.net/mp_blog/creation/editor/140671100。)

import albumentations as A from PIL import Image import numpy as np # 加载图像 image = np.array(Image.open('path_to_image.jpg')) # 定义一个增强管道

transform = A.Compose([ A.RandomCrop(width=256, height=256), A.HorizontalFlip(p=0.5), A.RandomBrightnessContrast(p=0.2), ]) # 应用增强

augmented_image = transform(image=image)['image']

这个例子中,我们定义了一个变换管道,包括随机裁剪、水平翻转和随机亮度对比度调整,然后将其应用到一张图像上。这种方法可以在训练神经网络时用来动态创建增强的图像数据。


以下是批量处理的python脚本思路:

代码思路:

  1. 初始化:在YOLOAug类的构造函数中,接收输入参数,包括原始图像和标签的路径、保存增强后图像和标签的路径、类别标签列表、需要增强的类别以及每个类别的目标计数。
  2. 数据增强:使用albumentations库定义了一系列的数据增强操作,包括像素级变换和空间级变换。这些操作可以增加图像的多样性,提高模型的泛化能力。
  3. 遍历处理:在aug_image方法中,遍历原始图像路径下的所有图像文件,对每张图像应用定义好的数据增强操作。
  4. 保存结果:对每张增强后的图像和对应的标签进行保存,使用uuid生成唯一的文件名以区分原始图像和增强后的图像。
  5. 执行:在main函数中创建YOLOAug类的实例,并调用aug_image方法来执行整个数据增强流程。

这段代码是一个用于数据增强的Python脚本,主要针对目标检测任务中的图像和标签文件进行增强。使用的数据增强库是albumentations,这是一个非常强大的图像增强库,支持各种像素级和空间级的变换。

读取的数据集目标的格式:
源目录:
datasetorg:绝对路径:"C:\Desktop\dataset_org"
--images:绝对路径:"C:\Desktop\dataset_org\images"
----.jpg
--labels:绝对路径:"C:\Desktop\dataset_org\labels"
.....txt
目标目录:
datasetdst:绝对路径:"C:\Desktop\dataset_dst"
--images:绝对路径:"C:\Desktop\dataset_dst\images"
----.jpg
--labels:绝对路径:"C:\Desktop\dataset_dst\labels"
.......txt

.txt文件的标签格式yolo格式:
3 0.385417 0.498148 0.0625 0.033333

数据集是这样子的,同时给出了标签的格式。labels目录中的.txt中的数据是目标检测的标签数据,每一行数据分别对应一个GT框,每一行数据从左到右分别是:id, x,y,w,h。列1 - 目标类别id ,  列2 - 目标中心位置x, 列3 - 目标中心位置y, 列4 - 目标宽度w,列5 - 目标高度h。x,y,w,h是小于1的浮点数,因为是经过对图像进行了归一化处理得到的值,也就是目标的真实的x,w值除以图像的宽度,y,h除以图像的高度。

以下代码基本实现了读取数据集使用albumentation数据增强,且可指定类别,具体使用哪些数据增强的方法还需要根据数据集具体分析且修改,代码中只是给出几个例子,没有展示每一个数据增强的方法。

import os
import cv2
import albumentations as A
from tqdm import tqdm
import uuid


class YOLOAug:
    def __init__(self, pre_image_path, pre_label_path, aug_save_image_path, aug_save_label_path, labels, classes_to_augment, target_counts):
        """
        初始化YOLOAug类的实例。
        :param pre_image_path: 原始图像文件夹路径。
        :param pre_label_path: 原始标签文件夹路径。
        :param aug_save_image_path: 增强后图像保存路径。
        :param aug_save_label_path: 增强后标签保存路径。
        :param labels: 所有类别的标签列表。
        :param classes_to_augment: 需要进行数据增强的类别列表。
        :param target_counts: 每个类别增强后的目标数量。
        """
        self.pre_image_path = pre_image_path
        self.pre_label_path = pre_label_path
        self.aug_save_image_path = aug_save_image_path
        self.aug_save_label_path = aug_save_label_path
        self.labels = labels
        self.classes_to_augment = classes_to_augment
        self.target_counts = target_counts
        # 定义数据增强的组合,包括像素级变换和空间级变换
        self.aug = A.Compose([
            # #1、Pixel-level transforms(包含了颜色变换、噪声和模糊 )
            # AdvancedBlur  # (使用随机选择参数的广义正态滤波器)
            # Blur  # (模糊)
            A.CLAHE(clip_limit=2.0, tile_grid_size=(8, 8), p=0.05),  # CLAHE  # (限制对比度自适应直方图均衡化)  ####
            # ChannelDropout  # (随机drop一个或多个通道)
            # ChannelShuffle  # (通道打乱)
            A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1, p=0.02),  # ColorJitter  # (色彩抖动【亮度、对比度、饱和度】)  ####
            # Defocus  # (虚焦)
            # Downscale  # (降质)
            A.Emboss(alpha=(0.2, 0.5), strength=(0.2, 0.7), p=0.05), # Emboss  # (浮雕效果) ####
            A.Equalize(mode='cv', by_channels=True, p=0.05), # Equalize  # (直方图均衡) ####
            # FDA  # (Fourier-Domain-Adaptation,实现简单的风格迁移)
            A.FancyPCA(alpha=0.1, p=0.02), # FancyPCA  # (RGB图像色彩增强)  ####
            # FromFloat  # (乘最大值变整型,与ToFloat相反)
            # GaussNoise  # (高斯噪声)
            # GaussianBlur  # (高斯模糊)
            # GlassBlur  # (玻璃模糊)
            # HistogramMatching  # (直方图匹配,会引起色调变化)
            A.HueSaturationValue(hue_shift_limit=20, sat_shift_limit=30, val_shift_limit=20, p=0.02), # HueSaturationValue  # (色调、饱和度、亮度)   ####
            A.ISONoise(color_shift=(0.01, 0.05), intensity=(0.1, 0.5), p=0.02), # ISONoise  # (传感器噪声)  ####
            # ImageCompression  # (图像压缩)
            # InvertImg  # (255-img)
            A.MedianBlur(blur_limit=5, p=0.05), # MedianBlur  # (中值滤波)  ####
            # MotionBlur  # (运动模糊)
            # MultiplicativeNoise  # (乘性噪声)
            # Normalize  # (归一化)
            # A.PixelDistributionAdaptation(p=0.02),# PixelDistributionAdaptation   #可能是调整图像中像素分布的技术    ####
            # Posterize  # (色调分层)
            # RGBShift  # (RGB每个通道上值偏移)
            A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.02),# RandomBrightnessContrast  # (亮度、对比度)     ####
            A.RandomFog(p=0.01),# RandomFog  # (雾效果)    ####
            A.RandomGamma(gamma_limit=(80, 120),p=0.02),# RandomGamma  # (gamma变换)   ####
            A.RandomRain(slant_lower=-10, slant_upper=10, drop_length=20,p=0.02),# RandomRain  # (下雨效果)    ####
            A.RandomShadow(num_shadows_lower=1, num_shadows_upper=3,p=0.02),# RandomShadow  # 随机阴影,模拟自然环境中的光影效果    ####
            A.RandomSnow(snow_point_lower=0.1, snow_point_upper=0.3,p=0.02), # RandomSnow  # 雪花效果,模拟冬季的气候条件    ####
            A.RandomSunFlare(flare_roi=(0, 0, 1, 1), angle_lower=0.3, angle_upper=0.7,p=0.02),# RandomSunFlare  # (太阳耀斑效果)     ####
            A.RandomToneCurve(scale=0.1,p=0.02),# RandomToneCurve  # 调整图像的色调曲线,改变图像的色彩和对比度。     ####
            # RingingOvershoot  # 在图像中引入振铃效果,常见于图像处理中的过度锐化。
            A.Sharpen(alpha=(0.2, 0.5), lightness=(0.5, 1.0),p=0.05),# Sharpen  # (锐化)    ####
            # Solarize  # (大于阈值反转)
            # Spatter  # (镜头雨点泥点飞溅效果)
            # Superpixels  # (超像素)
            # TemplateTransform  # 模板转换,通过一定的模板改变图像的一部分。
            # ToFloat  # (除最大值归一化,与FromFloat相反)
            # ToGray  # (转灰度(三通道))
            # ToRGB  # (灰度转三通道RGB)
            # ToSepia  # (加棕褐色滤镜)
            # A.UnsharpMask(blur_limit=(3,7),p=0.05),# UnsharpMask  # (锐化)      ####
            # ZoomBlur  # (变焦模糊)

            # # 2、Spatial-level transforms(包含了几何变换和空间变换。但是没有合成,比如cutout、mosaic)
            # Affine  # 仿射变换,包括缩放、旋转、平移
            # BBoxSafeRandomCrop  # (包含所有bboxes的裁剪)
            # CenterCrop  # (裁剪中心区域)
            # CoarseDropout  # (矩形区域cutout)
            # Crop  # (裁切)
            # CropAndPad  # (裁剪或填充图像边缘)
            # CropNonEmptyMaskIfExists  # (裁剪+缩放,可以忽略mask部分区域)
            # ElasticTransform  # (弹性变换)
            # Flip  # (翻转)
            # GridDistortion  # (网格畸变)
            # GridDropout  # (网格状cutout)
            # HorizontalFlip  # (水平翻转)
            # Lambda  # 应用自定义函数进行图像变换,为高度自定义的处理提供接口。后续添加合成(cutout、mixup、mosaic)变换使用
            # LongestMaxSize  # (长边等比例缩放至指定size)
            # MaskDropout  # (随机抹除目标实例)
            # NoOp  # (无操作)
            # OpticalDistortion  # (光学畸变(桶形、枕形))
            # PadIfNeeded  # (边界填充)
            # Perspective  # (透视变换)
            # PiecewiseAffine  # (局部仿射变换,效果类似ElasticTransform,但速度很慢)
            # PixelDropout  # (随机丢弃像素值)
            # RandomCrop  # (随机裁剪)
            # RandomCropFromBorders  # (图像边缘裁剪,会改变尺寸)
            # RandomCropNearBBox  # (指定rect附近裁剪)
            # RandomGridShuffle  # (分块打乱)
            # RandomResizedCrop  # (裁剪+缩放,裁剪区域宽高比随机)
            # RandomRotate90  # (随机旋转90度n次,即0°,90°,180°,270°随机旋转)

            A.OneOf([
                 A.RGBShift(r_shift_limit=50, g_shift_limit=50, b_shift_limit=50, p=0.5),
                 A.ChannelShuffle(p=0.05),  # 随机排列通道
                 # A.ColorJitter(p=0.3),  # 随机改变图像的亮度、对比度、饱和度、色调
                 A.ChannelDropout(p=0.05),  # 随机丢弃通道
             ], p=0.05),
             # A.Downscale(p=0.1),  # 随机缩小和放大来降低图像质量
             # A.Emboss(p=0.2),  # 压印输入图像并将结果与原始图像叠加
        ],
             # yolo: [x_center, y_center, width, height]  # 经过归一化
             # min_area: 表示bbox占据的像素总个数, 当数据增强后, 若bbox小于这个值则从返回的bbox列表删除该bbox.
             # min_visibility: 值域为[0,1], 如果增强后的bbox面积和增强前的bbox面积比值小于该值, 则删除该bbox
            bbox_params=A.BboxParams(format='yolo', min_area=0., min_visibility=0., label_fields=['category_id'],clip=True)
        )


    def aug_image(self):
        """
        对原始图像进行数据增强,并将增强后的图像和标签保存到指定路径。
        """
        # 确保保存增强图像和标签的目录存在
        os.makedirs(self.aug_save_image_path, exist_ok=True)
        os.makedirs(self.aug_save_label_path, exist_ok=True)
        # 遍历原始图像文件夹中的所有图像文件
        for image_filename in tqdm(os.listdir(self.pre_image_path), desc="Augmenting Images", unit="image"):
            # 构造图像和标签的完整路径
            image_path = os.path.join(self.pre_image_path, image_filename)
            label_path = os.path.join(self.pre_label_path, image_filename.replace('.jpg', '.txt'))
            # 如果标签文件不存在,则跳过
            if not os.path.exists(label_path):
                continue
            # 读取图片和标签
            image = cv2.imread(image_path)
            with open(label_path, 'r') as file:
                lines = file.readlines()
            # 解析标签文件中的边界框信息
            bboxes = [line.strip().split() for line in lines]
            category_ids = [int(bbox[0]) for bbox in bboxes]
            bboxes = [[float(x) for x in bbox[1:]] for bbox in bboxes]
            for i in range(5):  # 假设我们为每张原图生成5张增强图
                # 应用数据增强
                augmented = self.aug(image=image, bboxes=bboxes, category_id=category_ids)
                new_image = augmented['image']
                new_bboxes = augmented['bboxes']
                new_category_ids = augmented['category_id']
                # 使用uuid生成增强图像和标签文件的唯一名称
                unique_id = uuid.uuid4().hex
                new_image_filename = f"{image_filename.split('.')[0]}_{unique_id}.jpg"
                new_label_filename = f"{image_filename.split('.')[0]}_{unique_id}.txt"
                # 保存增强后的图片和标签
                cv2.imwrite(os.path.join(self.aug_save_image_path, new_image_filename), new_image)
                with open(os.path.join(self.aug_save_label_path, new_label_filename), 'w') as new_label_file:
                    for category_id, bbox in zip(new_category_ids, new_bboxes):
                        new_label_file.write(f"{category_id} {' '.join([f'{b:.6f}' for b in bbox])}\n")


def main():
    """
    主函数,用于创建YOLOAug类的实例并调用数据增强方法。
    """
    # 创建YOLOAug对象,设置数据增强的参数
    yolo_aug = YOLOAug(
        pre_image_path= ,  #数据增强之前的数据集image目录路径
        pre_label_path= ,  #数据增强之前的数据集label目录路径
        aug_save_image_path= , #数据增强之后的数据集image目录路径
        aug_save_label_path= , #数据增强之后的数据集label目录路径
        labels=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
        # classes_to_augment=[0,1,2,3,4,5,6,7,8,9],
        # target_counts={0:30000,1:30000,2:30000,3:30000,4:30000,5:30000,6:30000,7:30000,8:30000,9:30000}

        classes_to_augment = [6,7],
        target_counts = {6:1000,7:1000}

    )
    # 调用数据增强方法
    yolo_aug.aug_image()


# 程序入口,调用main函数
if __name__ == "__main__":
    main()

                
  • 7
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值