Albumentations是一个用于图像增强的Python库,它提供了多种增强技术,包括随机裁剪、旋转、缩放、翻转、变形、颜色变换、模糊等操作。使用Albumentations库可以快速、高效地对图像数据进行增强,从而提升机器学习模型的鲁棒性。
本人基于优秀的Albumentations数据增强库进行了二次封装,可对YOLO格式的数据集生成增强后的图片及对应的标签文件;只需修改代码中的路径即可直接运行。
from albumentations import *
import os
import cv2
from tqdm import tqdm
class enhancement:
    """Augment a YOLO-format dataset with Albumentations and save the results.

    Each image in ``picture_path`` is paired with the label file of the same
    stem in ``label_path`` (``xxx.jpg`` <-> ``xxx.txt``). The pair is run through
    the pipeline returned by :meth:`get_transform`, and the result is written
    under the original file names to ``save_img_path`` / ``save_lable_path``.
    """

    # Class-level defaults, so subclasses that override __init__ without
    # calling super() still get the original behaviour.
    draw_boxes = True
    batch_size = 10

    def __init__(self, picture_path, label_path, save_img_path, save_lable_path,
                 *, draw_boxes=True, batch_size=10):
        """
        :param picture_path: directory containing the source images.
        :param label_path: directory containing the YOLO ``.txt`` label files.
        :param save_img_path: output directory for augmented images.
        :param save_lable_path: output directory for augmented label files.
        :param draw_boxes: when True (the original behaviour), the augmented
            boxes are drawn onto each image before it is saved. Pass False when
            the output is meant to be used as training data.
        :param batch_size: number of image/label pairs per progress-bar step.
        """
        # Pair files by stem rather than by sorted position. Otherwise a stray
        # file (e.g. ``classes.txt``) or an image without a label would shift
        # every following image onto the wrong label.
        available_labels = set(os.listdir(label_path))
        self.picture_name, self.label_name = [], []
        for name in sorted(os.listdir(picture_path)):
            if name.lower().endswith('.txt'):
                # Images and labels may share one directory; never treat a
                # label file as an image.
                continue
            label = os.path.splitext(name)[0] + '.txt'
            if label in available_labels:
                self.picture_name.append(name)
                self.label_name.append(label)
            else:
                print(f'{name}没有对应的标签文件,跳过')
        # os.path.join works whether or not the directory ends with a separator.
        self.picture_path = [os.path.join(picture_path, i) for i in self.picture_name]
        self.label_path = [os.path.join(label_path, i) for i in self.label_name]
        self.save_img_path = save_img_path
        self.save_lable_path = save_lable_path
        self.draw_boxes = draw_boxes
        self.batch_size = batch_size

    def iter(self):
        """Yield ``(image_paths, label_paths, [start, stop])`` per batch.

        ``start``/``stop`` index into ``self.picture_name``. On the last batch,
        ``stop`` may exceed the number of files; callers zip against the
        path lists, so the extra indices are unused.
        """
        batch_size = self.batch_size
        for index_bin in tqdm(range(0, len(self.picture_path), batch_size), desc='批次进度'):
            picture_batch = self.picture_path[index_bin:index_bin + batch_size]
            label_batch = self.label_path[index_bin:index_bin + batch_size]
            yield picture_batch, label_batch, [index_bin, index_bin + batch_size]

    def get_transform(self):
        '''
        Build the augmentation pipeline; edit the list below to change which
        augmentations are applied.

        :return: an albumentations ``Compose`` that expects YOLO-format bboxes,
            with their class ids passed through the ``class_labels`` field.
        '''
        return Compose([
            # Mean (box) blur with a kernel size of at most 7, applied half the time.
            Blur(blur_limit=7, p=0.5),
            # Vertical flip (upside down), applied half the time.
            VerticalFlip(p=0.5),
            # Horizontal flip (mirror left/right); p=1, so always applied.
            HorizontalFlip(p=1),
            # Center crop to 200x200 (height, width), always applied.
            # NOTE(review): images smaller than 200x200 may be rejected,
            # depending on the albumentations version — confirm.
            CenterCrop(200, 200, p=1.0),
            # Other augmentations can be added here, e.g.:
            # RandomFog(fog_coef_lower=0.3, fog_coef_upper=0.7, alpha_coef=0.08, p=1),
            # RandomCrop(width=200, height=200),
        ], bbox_params=BboxParams(format='yolo', label_fields=['class_labels']))

    def augmentations(self, image, bboxes, class_labels):
        """Apply the pipeline from :meth:`get_transform` to one image.

        :return: ``(augmented_image, augmented_bboxes, augmented_labels)``.
        """
        transform = self.get_transform()
        transformed = transform(image=image, bboxes=bboxes, class_labels=class_labels)
        return transformed['image'], transformed['bboxes'], transformed['class_labels']

    @staticmethod
    def _read_yolo_labels(l_path):
        """Parse a YOLO label file into ``(class_labels, bboxes)``.

        Each non-blank line is ``class x_center y_center width height``.
        Blank lines and repeated whitespace are ignored. The last line is kept
        whether or not the file ends with a newline.
        """
        class_labels, bboxes = [], []
        # utf-8-sig also tolerates a BOM written by some Windows editors.
        with open(l_path, 'r', encoding='utf-8-sig') as f:
            for line in f:
                fields = line.split()
                if not fields:
                    continue
                class_labels.append(int(fields[0]))
                bboxes.append([float(v) for v in fields[1:]])
        return class_labels, bboxes

    def augmented_image_bboxes(self, img_path, l_path):
        """Read one image/label pair and augment it.

        :return: ``(augmented_image, augmented_bboxes, augmented_labels, original_image)``.
        :raises OSError: if the label file cannot be opened or the image cannot be decoded.
        """
        class_labels, original_bboxes = self._read_yolo_labels(l_path)
        original_image = cv2.imread(img_path)
        if original_image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise OSError(f'无法读取图片: {img_path}')
        augmented_image, augmented_bboxes, augmented_labels = self.augmentations(
            original_image, original_bboxes, class_labels)
        return augmented_image, augmented_bboxes, augmented_labels, original_image

    def parsing_data(self, p_l_i):
        """Augment and save one ``(image_path, label_path, index)`` triple.

        Unreadable pairs and pairs whose augmented label list is empty are
        reported and skipped, so one bad file does not abort the whole run.
        """
        img_path, l_path, index = p_l_i[:3]
        try:
            (self.augmented_image, self.augmented_bboxes,
             augmented_labels, _) = self.augmented_image_bboxes(img_path, l_path)
        except OSError as err:
            print(f'{self.picture_name[index]}读取失败,跳过: {err}')
            return
        # One YOLO line per box: "class x_center y_center width height".
        data = '\n'.join(
            ' '.join(map(str, [label, *bbox]))
            for label, bbox in zip(augmented_labels, self.augmented_bboxes)
        )
        # len() rather than truthiness, in case a numpy array is returned.
        if len(augmented_labels) > 0:
            if self.draw_boxes:
                # Draws in place, so the saved image will contain the boxes.
                self.show_img()
            self.save_img_lable(data, self.augmented_image, self.save_img_path,
                                self.save_lable_path, index)
        else:
            print(f'{self.picture_name[index]}该图片没有标签,不做保存')

    def save_img_lable(self, data, img, save_img_path, save_lable_path, index):
        """Write the image and the YOLO label text under the original file names.

        Output directories are created if missing.

        :raises OSError: if OpenCV fails to write the image.
        """
        for directory in (save_img_path, save_lable_path):
            if directory:
                os.makedirs(directory, exist_ok=True)
        img_file = os.path.join(save_img_path, self.picture_name[index])
        # cv2.imwrite reports failure by returning False instead of raising.
        if not cv2.imwrite(img_file, img):
            raise OSError(f'图片保存失败: {img_file}')
        with open(os.path.join(save_lable_path, self.label_name[index]), 'w',
                  encoding='utf-8') as f:
            f.write(data)

    def __call__(self):
        """Process every paired image/label, one batch at a time."""
        for picture_batch, label_batch, (start, stop) in self.iter():
            for item in zip(picture_batch, label_batch, range(start, stop)):
                self.parsing_data(item)

    def show_img(self, boxe=True):
        '''
        Draw the current augmented boxes onto ``self.augmented_image`` in place.

        When ``boxe`` is True, the boxes end up in the saved image. The color
        (255, 0, 0) is blue in OpenCV's BGR channel order.
        To preview interactively, call cv2.imshow / cv2.waitKey on
        ``self.augmented_image``.
        '''
        if not boxe:
            return
        height, width = self.augmented_image.shape[:2]
        for x, y, w, h in self.augmented_bboxes:
            # Convert normalized YOLO (center, size) to pixel corner coordinates.
            top_left = (int((x - w / 2) * width), int((y - h / 2) * height))
            bottom_right = (int((x + w / 2) * width), int((y + h / 2) * height))
            cv2.rectangle(self.augmented_image, top_left, bottom_right, (255, 0, 0), 2)
if __name__ == '__main__':
    # Source images and their YOLO label files. The trailing separator is kept
    # because paths may be joined to file names by plain string concatenation.
    picture_path = 'D:\\mydemo\\mask_data\\MaskDataset\\images\\train\\'
    label_path = 'D:\\mydemo\\mask_data\\MaskDataset\\labels\\train\\'
    # Output directories for the augmented images and labels. Labels go under
    # labels\ (mirroring images\), which is where YOLOv5/Ultralytics-style
    # loaders typically look for them.
    save_img_path = 'D:\\mydemo\\mask_data\\MaskDataset\\images\\dd\\'
    save_lable_path = 'D:\\mydemo\\mask_data\\MaskDataset\\labels\\dd\\'
    # cv2.imwrite does not create directories, so make sure they exist.
    for output_dir in (save_img_path, save_lable_path):
        os.makedirs(output_dir, exist_ok=True)
    c = enhancement(picture_path=picture_path,
                    label_path=label_path,
                    save_img_path=save_img_path,
                    save_lable_path=save_lable_path)
    c()
效果如下:
增强前后的标签以及图片变化: