Albumentations 是一个 Python 图像处理库,可用于深度学习网络训练时的图像数据增强.
Albumentations 图像数据增强库特点:
基于高度优化的 OpenCV 库实现图像快速数据增强.
为分割、检测等不同图像任务提供简单易用的 API 接口.
易于个性化定制.
易于添加到其它框架,比如 PyTorch.
1. Albumentations 的 pip 安装
sudo pip install albumentations #或 sudo pip install -U git+https://github.com/albu/albumentations
2. 不同图片数据增强库对比
测试条件: 使用 Intel Core i7-7800X CPU, 处理 ImageNet 验证集(validation set)中的前 2000 张图片.
不同数据增强库的处理速度对比(单位为秒, 耗时越少越好).
3. 使用示例
# Imports for the image-transform examples below.
# NOTE(review): the IAA* transforms, RandomContrast, RandomBrightness and Flip
# come from the older (pre-1.0) albumentations API and are deprecated or gone
# in newer releases -- confirm the installed albumentations version.
import cv2
import numpy as np
from matplotlib import pyplot as plt
from albumentations import (
    CLAHE,
    Blur,
    Compose,
    Flip,
    GaussNoise,
    GridDistortion,
    HorizontalFlip,
    HueSaturationValue,
    IAAAdditiveGaussianNoise,
    IAAEmboss,
    IAAPerspective,
    IAAPiecewiseAffine,
    IAASharpen,
    MedianBlur,
    MotionBlur,
    OneOf,
    OpticalDistortion,
    RandomBrightness,
    RandomContrast,
    RandomRotate90,
    ShiftScaleRotate,
    Transpose,
)

# Image transform examples
# --- Individual transforms (p=1 so every transform is always applied) ---
image = cv2.imread('test.jpg', 1)  # OpenCV loads colour images in BGR order
if image is None:
    # cv2.imread signals a missing/unreadable file by returning None rather
    # than raising, which would otherwise surface as an obscure cvtColor error.
    raise FileNotFoundError("could not read image file 'test.jpg'")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # RGB for display with matplotlib

aug = HorizontalFlip(p=1)
img_HorizontalFlip = aug(image=image)['image']

aug = IAAPerspective(scale=0.2, p=1)
img_IAAPerspective = aug(image=image)['image']

aug = ShiftScaleRotate(p=1)
img_ShiftScaleRotate = aug(image=image)['image']


def augment_flips_color(p=.5):
    """Build a pipeline of rotation/transposition, distortion and colour transforms.

    Args:
        p: probability that the whole composed pipeline is applied to a sample.
            Inner transforms without an explicit ``p`` use the library's
            default probability.

    Returns:
        An albumentations ``Compose`` callable; call it as ``aug(image=img)``
        and read the augmented image from the ``'image'`` key.
    """
    return Compose([
        CLAHE(),
        RandomRotate90(),
        Transpose(),
        ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.50, rotate_limit=45, p=.75),
        Blur(blur_limit=3),
        OpticalDistortion(),
        GridDistortion(),
        HueSaturationValue(),
    ], p=p)
aug = augment_flips_color(p=1)
img_augment_flips_color = aug(image=image)['image']


def strong_aug(p=.5):
    """Build a heavier pipeline mixing geometry, noise, blur, distortion and colour.

    Each ``OneOf`` group applies a single one of its member transforms when the
    group fires; the group's own ``p`` controls how often it fires.

    Args:
        p: probability that the whole composed pipeline is applied to a sample.

    Returns:
        An albumentations ``Compose`` callable used as ``aug(image=img)['image']``.
    """
    # NOTE(review): IAA*/RandomContrast/RandomBrightness/Flip are pre-1.0
    # albumentations transforms; confirm they exist in the installed version.
    return Compose([
        RandomRotate90(),
        Flip(),
        Transpose(),
        OneOf([  # additive noise
            IAAAdditiveGaussianNoise(),
            GaussNoise(),
        ], p=0.2),
        OneOf([  # blur
            MotionBlur(p=.2),
            MedianBlur(blur_limit=3, p=.1),
            Blur(blur_limit=3, p=.1),
        ], p=0.2),
        ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.2, rotate_limit=45, p=.2),
        OneOf([  # geometric distortion
            OpticalDistortion(p=0.3),
            GridDistortion(p=.1),
            IAAPiecewiseAffine(p=0.3),
        ], p=0.2),
        OneOf([  # contrast / sharpness / brightness
            CLAHE(clip_limit=2),
            IAASharpen(),
            IAAEmboss(),
            RandomContrast(),
            RandomBrightness(),
        ], p=0.3),
        HueSaturationValue(p=0.3),
    ], p=p)
# Assignment (not `==`): a comparison here would leave `aug` bound to the
# augment_flips_color pipeline, so the last panel would show the wrong result.
aug = strong_aug(p=1)
img_strong_aug = aug(image=image)['image']

# Show the original image and every augmented variant in a 2x3 grid.
panels = [
    image,
    img_HorizontalFlip,
    img_IAAPerspective,
    img_ShiftScaleRotate,
    img_augment_flips_color,
    img_strong_aug,
]
for position, panel in enumerate(panels, start=1):
    plt.subplot(2, 3, position)
    plt.imshow(panel)
plt.show()
# NOTE(review): the IAA* transforms, RandomContrast, RandomBrightness and Flip
# belong to the older (pre-1.0) albumentations API -- confirm the installed version.
import numpy as np
from albumentations import (
    CLAHE,
    Blur,
    Compose,
    Flip,
    GaussNoise,
    GridDistortion,
    HueSaturationValue,
    IAAAdditiveGaussianNoise,
    IAAEmboss,
    IAAPiecewiseAffine,
    IAASharpen,
    MedianBlur,
    MotionBlur,
    OneOf,
    OpticalDistortion,
    RandomBrightness,
    RandomContrast,
    RandomRotate90,
    ShiftScaleRotate,
    Transpose,
)


def strong_aug(p=0.5):
    """Build a heavier pipeline mixing geometry, noise, blur, distortion and colour.

    Each ``OneOf`` group applies a single one of its member transforms when the
    group fires; the group's own ``p`` controls how often it fires.

    Args:
        p: probability that the whole composed pipeline is applied to a sample.

    Returns:
        An albumentations ``Compose`` callable, invoked with keyword targets
        such as ``image=`` and ``mask=``; results are read back by key.
    """
    return Compose([
        RandomRotate90(),
        Flip(),
        Transpose(),
        OneOf([  # additive noise
            IAAAdditiveGaussianNoise(),
            GaussNoise(),
        ], p=0.2),
        OneOf([  # blur
            MotionBlur(p=0.2),
            MedianBlur(blur_limit=3, p=0.1),
            Blur(blur_limit=3, p=0.1),
        ], p=0.2),
        ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.2, rotate_limit=45, p=0.2),
        OneOf([  # geometric distortion
            OpticalDistortion(p=0.3),
            GridDistortion(p=0.1),
            IAAPiecewiseAffine(p=0.3),
        ], p=0.2),
        OneOf([  # contrast / sharpness / brightness
            CLAHE(clip_limit=2),
            IAASharpen(),
            IAAEmboss(),
            RandomContrast(),
            RandomBrightness(),
        ], p=0.3),
        HueSaturationValue(p=0.3),
    ], p=p)
# Dummy inputs: a 300x300 3-channel uint8 image and a matching 2-D mask.
image = np.ones((300, 300, 3), dtype=np.uint8)
mask = np.ones((300, 300), dtype=np.uint8)
whatever_data = "my name"
augmentation = strong_aug(p=0.9)

# Every target is passed as a keyword argument and read back from the result
# by the same key. NOTE(review): the extra non-image keys ("whatever_data",
# "additional") are expected to come back in the result, presumably
# unchanged -- how unknown keys are handled depends on the albumentations
# version; confirm.
data = {"image": image, "mask": mask, "whatever_data": whatever_data, "additional": "hello"}
augmented = augmentation(**data)  # data augmentation
image, mask, whatever_data, additional = (
    augmented["image"],
    augmented["mask"],
    augmented["whatever_data"],
    augmented["additional"],
)
4. 更新的使用示例
4.1 综合示例 - showcase
#导入相关库,并定义用于可视化的函数#!--*-- coding: utf-8 --*--
importosimportnumpy as npimportcv2from matplotlib importpyplot as pltfrom skimage.color importlabel2rgbimportalbumentations as Aimportrandom
BOX_COLOR= (255, 0, 0)
TEXT_COLOR= (255, 255, 255)def visualize_bbox(img, bbox, color=BOX_COLOR, thickness=2, **kwargs):#height, width = img.shape[:2]
x_min, y_min, w, h =bbox
x_min, x_max, y_min, y_max= int(x_min), int(x_min + w), int(y_min), int(y_min +h)
cv2.rectangle(img, (x_min, y_min), (x_max, y_max), color=color, thickness=thickness)returnimgdef visualize_titles(img, bbox, title, color=BOX_COLOR, thickness=2, font_thickness = 2, font_scale=0.35, **kwargs):#height, width = img.shape[:2]
x_min, y_min, w, h =bbox
x_min, x_max, y_min, y_max= int(x_min), int(x_min + w), int(y_min), int(y_min +h)
((text_width, text_height), _)=cv2.getTextSize(title, cv2.FONT_HERSHEY_