yolov5-基于Albumentations的图像增强

Albumentations 是一个快速、灵活的图像增强库,广泛用于工业界、深度学习研究、机器学习竞赛和开源项目。Albumentations 用 Python 编写,采用 MIT 许可证发布。源代码:https://github.com/albumentations-team/albumentations(仓库简介:Fast and flexible image augmentation library);相关论文:https://www.mdpi.com/2078-2489/11/2/125

首先安装:

pip install -U albumentations

整理代码如下:

from albumentations import *
import os
import cv2
from tqdm import tqdm

os.environ['ALBUMENTATIONS_SKIP_VERSION_CHECK'] = '1'
 
class enhancement:
    """Augment a folder of YOLO-labelled images with Albumentations and save the results.

    Images and label files are paired by position after sorting both directory
    listings, so the two folders must contain matching files (e.g. ``a.jpg`` and
    ``a.txt``). Augmented images and labels are written under the original file
    names into ``save_img_path`` / ``save_lable_path``.
    """

    # Number of files handled per progress-bar step in ``iter``.
    batch_size = 10
    # Whether ``parsing_data`` draws the augmented boxes onto the saved image.
    draw_boxes = True

    def __init__(self, picture_path, label_path, save_img_path, save_lable_path, draw_boxes=True):
        """
        :param picture_path: folder with the source images.
        :param label_path: folder with the YOLO ``.txt`` label files.
        :param save_img_path: output folder for augmented images.
        :param save_lable_path: output folder for augmented labels.
        :param draw_boxes: draw the boxes into the saved images (original
            behaviour, handy for eyeballing results). Pass ``False`` when the
            output is training data, otherwise the rectangles become part of
            the pixels.
        """
        self.picture_name = sorted(os.listdir(picture_path))
        self.label_name = sorted(os.listdir(label_path))
        if len(self.picture_name) != len(self.label_name):
            # Pairing is positional, so a missing file can shift later pairs.
            print(f'Warning: {len(self.picture_name)} images but {len(self.label_name)} '
                  f'label files; image/label pairs may be mismatched')
        # os.path.join works whether or not the folder ends with a separator.
        self.picture_path = [os.path.join(picture_path, name) for name in self.picture_name]
        self.label_path = [os.path.join(label_path, name) for name in self.label_name]
        self.save_img_path = save_img_path
        self.save_lable_path = save_lable_path
        self.draw_boxes = draw_boxes

    def iter(self):
        """Yield ``(image_paths, label_paths, [start, stop])`` in chunks of ``batch_size``.

        On the last chunk ``stop`` may exceed the number of files; ``__call__``
        zips it against the path lists, which truncates to the real length.
        """
        step = self.batch_size
        for start in tqdm(range(0, len(self.picture_path), step), desc='批次进度'):
            yield (self.picture_path[start:start + step],
                   self.label_path[start:start + step],
                   [start, start + step])

    def get_transform(self):
        """Build the augmentation pipeline; edit this method to change the augmentations.

        Boxes are given in YOLO format and their class ids travel alongside
        them in the ``class_labels`` field.
        """
        transform = Compose([
            # Mean (box) blur.
            Blur(blur_limit=7, p=0.5),
            # Vertical flip (upside down).
            VerticalFlip(p=0.5),
            # Horizontal flip (mirror left-right); p=1 flips every image.
            HorizontalFlip(p=1),
            Rotate(45),  # random rotation, limit 45 degrees
            # Random sun flare, restricted to the region (0, 0, 1, 0.5).
            RandomSunFlare(flare_roi=(0, 0, 1, 0.5)),
            # Other options:
            # CenterCrop(200, 200, p=1.0),
            # RandomFog(fog_coef_lower=0.3, fog_coef_upper=0.7, alpha_coef=0.08, p=1),
            # RandomCrop(width=200, height=200),
        ], bbox_params=BboxParams(format='yolo', label_fields=['class_labels']))
        return transform

    def augmentations(self, image, bboxes, class_labels):
        """Run the pipeline on one image; return ``(image, bboxes, class_labels)`` after augmentation."""
        transform = self.get_transform()
        transformed = transform(image=image, bboxes=bboxes, class_labels=class_labels)
        return transformed['image'], transformed['bboxes'], transformed['class_labels']

    @staticmethod
    def _parse_label_text(text):
        """Parse YOLO label text into ``(class_ids, boxes)``.

        Each non-blank line is ``class v1 v2 ...``. Blank lines and extra
        whitespace are ignored, and the last line is kept even without a
        trailing newline.
        """
        class_labels, bboxes = [], []
        for line in text.splitlines():
            fields = line.split()
            if not fields:
                continue
            class_labels.append(int(fields[0]))
            bboxes.append([float(v) for v in fields[1:]])
        return class_labels, bboxes

    @staticmethod
    def _format_label_text(class_labels, bboxes):
        """Serialise class ids and boxes to label-file text: one ``class v1 v2 ...`` line per box."""
        return ''.join(
            f'{int(c)} ' + ' '.join(f'{float(v):.6f}' for v in box) + '\n'
            for c, box in zip(class_labels, bboxes)
        )

    def augmented_image_bboxes(self, img_path, l_path):
        """Load one image and its YOLO label file, then augment them.

        :return: ``(augmented_image, augmented_bboxes, augmented_labels, original_image)``.
        :raises FileNotFoundError: if OpenCV cannot read ``img_path``.
        """
        with open(l_path, 'r', encoding='utf-8') as f:
            class_labels, original_bboxes = self._parse_label_text(f.read())
        original_image = cv2.imread(img_path)
        if original_image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f'cv2.imread could not read {img_path}')
        augmented_image, augmented_bboxes, augmented_labels = self.augmentations(
            original_image, original_bboxes, class_labels)
        return augmented_image, augmented_bboxes, augmented_labels, original_image

    def parsing_data(self, p_l_i):
        """Augment one ``(image_path, label_path, index)`` triple and save the result.

        ``index`` is the position in the sorted listings and selects the output
        file names. An image is not saved when the augmented label list is
        empty (the label file was empty, or the pipeline dropped every box).
        """
        img_path, l_path, index = p_l_i[:3]
        (self.augmented_image, self.augmented_bboxes,
         augmented_labels, _original_image) = self.augmented_image_bboxes(img_path, l_path)
        data = self._format_label_text(augmented_labels, self.augmented_bboxes)
        # len() works for both lists and NumPy arrays (truthiness of an array raises).
        if len(augmented_labels) > 0:
            self.show_img(boxe=self.draw_boxes)
            self.save_img_lable(data, self.augmented_image, self.save_img_path, self.save_lable_path, index)
        else:
            print(f'{self.picture_name[index]}该图片没有标签,不做保存')

    def save_img_lable(self, data, img, save_img_path, save_lable_path, index):
        """Write ``img`` and label text ``data`` using the source file names at ``index``.

        Output folders are created if missing. If OpenCV fails to write the
        image, a message is printed and the label file is skipped, so no orphan
        label is left behind.
        """
        for folder in (save_img_path, save_lable_path):
            if folder:
                os.makedirs(folder, exist_ok=True)
        img_file = os.path.join(save_img_path, self.picture_name[index])
        # cv2.imwrite reports failure through its return value, not an exception.
        if not cv2.imwrite(img_file, img):
            print(f'Failed to write {img_file}; label not saved')
            return
        with open(os.path.join(save_lable_path, self.label_name[index]), 'w', encoding='utf-8') as f:
            f.write(data)

    def __call__(self):
        """Augment and save every image/label pair, batch by batch."""
        for picture_batch, label_batch, (start, stop) in self.iter():
            for item in zip(picture_batch, label_batch, range(start, stop)):
                self.parsing_data(item)

    def show_img(self, boxe=True):
        """Draw ``self.augmented_bboxes`` onto ``self.augmented_image`` in place.

        Boxes are normalised ``(x_center, y_center, w, h)`` and are scaled by
        the image size. The drawing is in place, so the rectangles end up in
        the image that ``parsing_data`` saves afterwards.

        :param boxe: when False nothing is drawn.
        """
        if not boxe:
            return
        img_h, img_w = self.augmented_image.shape[:2]
        for box in self.augmented_bboxes:
            x, y, w, h = box[:4]
            top_left = (int((x - w / 2) * img_w), int((y - h / 2) * img_h))
            bottom_right = (int((x + w / 2) * img_w), int((y + h / 2) * img_h))
            # (255, 0, 0) is blue in OpenCV's BGR channel order.
            cv2.rectangle(self.augmented_image, top_left, bottom_right, (255, 0, 0), 2)

# 标签文件整理---------------------------------
def yolo_box_from_corners(coord, img_w, img_h):
    """Convert a two-corner pixel box to normalised YOLO ``(x_center, y_center, w, h)``.

    ``coord`` is ``[[x1, y1], [x2, y2]]``. The corners may come in either
    order; they are sorted so width and height are never negative.
    """
    x1, x2 = sorted((coord[0][0], coord[1][0]))
    y1, y2 = sorted((coord[0][1], coord[1][1]))
    return ((x1 + x2) / 2 / img_w,
            (y1 + y2) / 2 / img_h,
            (x2 - x1) / img_w,
            (y2 - y1) / img_h)


def json2txt(path=None,
             images_dir='DATA_Augmentation------------------------/original_img+label/images',
             labels_dir='DATA_Augmentation------------------------/original_img+label/labels',
             class_map=None):
    """Copy ``*.jpg`` images from ``path`` and convert their ``*.jpg.json`` annotations to YOLO txt.

    Each annotation ``X.jpg.json`` is expected to hold an ``"objects"`` list
    whose items have a ``"class"`` name and a two-point ``"coord"`` box in
    pixels. Objects whose class is in ``class_map`` are written to
    ``labels_dir/X.txt`` as ``id x_center y_center w h``, normalised by the
    size of the matching image ``X.jpg``.

    :param path: folder containing the images and JSON files (default: CWD).
    :param images_dir: output folder for the images (created if missing).
    :param labels_dir: output folder for the txt labels (created if missing).
    :param class_map: class name -> YOLO class id; defaults to ``{"Stall": 0}``.
    """
    import json
    import natsort
    from PIL import Image

    # os.path.join(None, ...) raises, so make the CWD default explicit.
    original_path = '.' if path is None else path
    if class_map is None:
        class_map = {"Stall": 0}
    os.makedirs(images_dir, exist_ok=True)
    os.makedirs(labels_dir, exist_ok=True)

    all_names = natsort.natsorted(os.listdir(original_path))
    img_names = [n for n in all_names if n.endswith(".jpg")]
    lab_names = [n for n in all_names if n.endswith(".jpg.json")]

    # Save every image into images_dir and remember its (width, height), so
    # each label is normalised by the size of its own image.
    # NOTE(review): Image.save re-encodes the JPEG with PIL's default settings
    # rather than copying the file byte-for-byte; confirm that is acceptable.
    sizes = {}
    for name in img_names:
        with Image.open(os.path.join(original_path, name)) as img:
            sizes[name] = img.size
            img.save(os.path.join(images_dir, name))

    for name in lab_names:
        img_name = name[:-len(".json")]  # "X.jpg.json" -> "X.jpg"
        if img_name not in sizes:
            print(f'{name}: no matching image {img_name}, label skipped')
            continue
        w, h = sizes[img_name]

        # Binary mode lets json detect UTF-8/16/32 (with or without BOM)
        # independently of the OS locale.
        with open(os.path.join(original_path, name), 'rb') as fj:
            json_lab = json.load(fj)

        new_path = os.path.join(labels_dir, name.replace(".jpg.json", ".txt"))
        with open(new_path, "w") as ff:
            for obj in json_lab["objects"]:
                cl = obj["class"]
                if cl in class_map:
                    xx, yy, ww, hh = yolo_box_from_corners(obj["coord"], w, h)
                    ff.write(f'{class_map[cl]} {xx:.6f} {yy:.6f} {ww:.6f} {hh:.6f}\n')
    
 
 
if __name__ == '__main__':
    # Folder with the raw *.jpg / *.jpg.json pairs. Uncomment json2txt(path1) to
    # convert them into the images/ and labels/ folders used below.
    path1 = r"DATA_Augmentation------------------------/original_img+label/originalimg"
    # json2txt(path1)

    # Source images and their YOLO labels.
    picture_path = 'DATA_Augmentation------------------------/original_img+label/images/'
    label_path = 'DATA_Augmentation------------------------/original_img+label/labels/'
    # Output folders for the augmented images and labels.
    save_img_path = 'DATA_Augmentation------------------------/original_img+label/img_aug/'
    save_lable_path = 'DATA_Augmentation------------------------/original_img+label/lab_aug/'

    # Create the output folders up front; saving into a missing folder fails.
    for out_dir in (save_img_path, save_lable_path):
        os.makedirs(out_dir, exist_ok=True)

    aug = enhancement(picture_path=picture_path,
                      label_path=label_path,
                      save_img_path=save_img_path,
                      save_lable_path=save_lable_path)
    aug()

效果如下:

原文链接:基于Albumentations的图像增强,对yolo数据,生成增强后的图片以及标签。_yolov8中加入albumentations-CSDN博客

  • 3
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值