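Thanks to: https://www.cnblogs.com/ziytong/p/10791475.html
albumentations: Data Augmentation for Detection Tasks (Part 3)
Tasks such as object detection need an augmentation step that returns not only the augmented image but also the bounding boxes of the objects in that augmented image.
For detection, the library provides two kinds of augmentation: pixel-level transforms and spatial-level transforms. Pixel-level transforms change only pixel values, so the bounding boxes stay where they are; spatial-level transforms change the image geometry, so the boxes have to be transformed along with it.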
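To make the distinction concrete, here is a minimal pure-Python sketch that does not use albumentations at all. The helper names `adjust_brightness` and `horizontal_flip` and the tiny nested-list image are assumptions made for illustration; they only mimic what the two kinds of transform do to an image and its boxes.

```python
def adjust_brightness(image, delta):
    """Pixel-level sketch: change pixel values, leave the geometry alone."""
    return [[min(255, max(0, px + delta)) for px in row] for row in image]


def horizontal_flip(image, bboxes):
    """Spatial-level sketch: mirror the image and move every box with it."""
    width = len(image[0])
    flipped = [row[::-1] for row in image]
    # Boxes are [x_min, y_min, width, height]; only x_min changes.
    moved = [[width - (x + w), y, w, h] for x, y, w, h in bboxes]
    return flipped, moved


# Hypothetical 2x4 grayscale image and a single box in its left column.
image = [[10, 20, 30, 40],
         [50, 60, 70, 80]]
bboxes = [[0, 0, 1, 2]]

brighter = adjust_brightness(image, 200)                  # boxes stay valid as they are
flipped, flipped_bboxes = horizontal_flip(image, bboxes)  # boxes must be updated
print(brighter)
print(flipped, flipped_bboxes)
```

After the brightness change the original `bboxes` still describe the object, while after the flip the box has moved from the left column to the right one.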
1. First, define the visualization functions
# Import the required libraries and define the visualization functions
import os
import numpy as np
import cv2
from matplotlib import pyplot as plt
from urllib.request import urlopen
from albumentations import (
    HorizontalFlip,
    VerticalFlip,
    Resize,
    CenterCrop,
    RandomCrop,
    Crop,
    Compose
)
# Helper functions that draw bounding boxes and class labels on an image
BOX_COLOR = (255, 0, 0)
TEXT_COLOR = (255, 255, 255)


def visualize_bbox(img, bbox, class_id, class_idx_to_name, color=BOX_COLOR, thickness=2):
    # bbox is in COCO format: [x_min, y_min, width, height]
    x_min, y_min, w, h = bbox
    x_min, x_max, y_min, y_max = int(x_min), int(x_min + w), int(y_min), int(y_min + h)
    cv2.rectangle(img, (x_min, y_min), (x_max, y_max), color=color, thickness=thickness)
    class_name = class_idx_to_name[class_id]
    ((text_width, text_height), _) = cv2.getTextSize(class_name, cv2.FONT_HERSHEY_SIMPLEX, 0.35, 1)
    # Filled rectangle behind the label so the text stays readable
    cv2.rectangle(img, (x_min, y_min - int(1.3 * text_height)), (x_min + text_width, y_min), BOX_COLOR, -1)
    cv2.putText(img, class_name, (x_min, y_min - int(0.3 * text_height)), cv2.FONT_HERSHEY_SIMPLEX, 0.35, TEXT_COLOR, lineType=cv2.LINE_AA)
    return img


def visualize(annotations, category_id_to_name):
    # Draw on a copy so the original image is left untouched
    img = annotations['image'].copy()
    for idx, bbox in enumerate(annotations['bboxes']):
        img = visualize_bbox(img, bbox, annotations['category_id'][idx], category_id_to_name)
    plt.figure(figsize=(12, 12))
    plt.imshow(img)
    plt.show()
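2. Detection task
For detection problems, bbox_params must be defined with an explicit bounding box format. This article covers two formats: coco and pascal_voc.
The coco bounding box format is: [x_min, y_min, width, height]
The pascal_voc bounding box format is: [x_min, y_min, x_max, y_max]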
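Since the two formats differ only in how the far corner is encoded, converting between them is simple arithmetic. The sketch below is plain Python; the helper names `coco_to_pascal_voc` and `pascal_voc_to_coco` and the sample box are assumptions for illustration, not functions provided by albumentations.

```python
def coco_to_pascal_voc(bbox):
    """[x_min, y_min, width, height] -> [x_min, y_min, x_max, y_max]"""
    x_min, y_min, width, height = bbox
    return [x_min, y_min, x_min + width, y_min + height]


def pascal_voc_to_coco(bbox):
    """[x_min, y_min, x_max, y_max] -> [x_min, y_min, width, height]"""
    x_min, y_min, x_max, y_max = bbox
    return [x_min, y_min, x_max - x_min, y_max - y_min]


coco_box = [10, 20, 30, 40]              # hypothetical box
voc_box = coco_to_pascal_voc(coco_box)   # far corner is (10 + 30, 20 + 40)
round_trip = pascal_voc_to_coco(voc_box)
print(voc_box, round_trip)
```

Whichever format the annotations use, the format string passed in bbox_params has to match it; otherwise a width would be read as an x_max, or the other way round.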
def get_aug(aug, min_area=0., min_visibility=0.):
    # Wrap the transforms in Compose and declare that boxes are in COCO format
    return Compose(aug, bbox_params={'format': 'coco', 'min_area': min_area, 'min_visibility': min_visibility, 'label_fields': ['category_id']})


def download_image(url):
    # Download the raw bytes and decode them into an RGB image
    data = urlopen(url).read()
    data = np.frombuffer(data, np.uint8)
    image = cv2.imdecode(data, cv2.IMREAD_COLOR)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    return image
image = download_image('http://images.cocodataset.org/train2017/000000386298.jpg')
# Annotations for image 386298 from COCO http://cocodataset.org/#explore?id=386298
annotations = {'image': image, 'bboxes': [[366.7, 80.84, 132.8, 181.84], [5.66, 138.95, 147.09, 164.88]], 'category_id': [18, 17]}
category_id_to_name = {17: 'cat', 18: 'dog'}
Visualize the annotations on the original image:
visualize(annotations, category_id_to_name)
Vertical flip augmentation:
aug = get_aug([VerticalFlip(p=1)])
augmented = aug(**annotations)
visualize(augmented, category_id_to_name)
Horizontal flip augmentation:
aug = get_aug([HorizontalFlip(p=1)])
augmented = aug(**annotations)
visualize(augmented, category_id_to_name)
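The effect of the two flips on a COCO-format box can be worked out by hand. The following is a minimal sketch in plain Python, not the library's implementation; the helper names, the 640x480 image size, and the sample box are all assumptions chosen for illustration.

```python
def hflip_bbox_coco(bbox, image_width):
    """Mirror a [x_min, y_min, w, h] box around the vertical centre line."""
    x_min, y_min, w, h = bbox
    return [image_width - (x_min + w), y_min, w, h]


def vflip_bbox_coco(bbox, image_height):
    """Mirror a [x_min, y_min, w, h] box around the horizontal centre line."""
    x_min, y_min, w, h = bbox
    return [x_min, image_height - (y_min + h), w, h]


IMAGE_WIDTH, IMAGE_HEIGHT = 640, 480   # assumed image size
bbox = [100, 50, 200, 120]             # hypothetical box

h_flipped = hflip_bbox_coco(bbox, IMAGE_WIDTH)    # new x_min = 640 - (100 + 200)
v_flipped = vflip_bbox_coco(bbox, IMAGE_HEIGHT)   # new y_min = 480 - (50 + 120)
print(h_flipped, v_flipped)
```

Width and height are unchanged by a flip, and applying the same flip twice returns the original box.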
Resize augmentation:
aug = get_aug([Resize(p=1, height=256, width=256)])
augmented = aug(**annotations)
visualize(augmented, category_id_to_name)
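Resizing scales box coordinates by the same factors as the image: x and width by new_width / old_width, y and height by new_height / old_height. Here is a small sketch of that arithmetic in plain Python; `resize_bbox_coco`, the 640x480 source size, and the sample box are assumptions for illustration.

```python
def resize_bbox_coco(bbox, old_size, new_size):
    """Scale a [x_min, y_min, w, h] box from old_size to new_size.

    Both sizes are (width, height) tuples.
    """
    old_w, old_h = old_size
    new_w, new_h = new_size
    x, y, w, h = bbox
    return [x * new_w / old_w, y * new_h / old_h,
            w * new_w / old_w, h * new_h / old_h]


# Hypothetical box in an assumed 640x480 image, resized to 256x256.
resized = resize_bbox_coco([100, 60, 200, 120], (640, 480), (256, 256))
print(resized)
```

Because the two scale factors differ when the aspect ratio changes, a box that was 200x120 does not keep its proportions after resizing to a square.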
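The albumentations library also supports cropping and dropping boxes. This is controlled by two parameters: min_area and min_visibility.
Both min_area and min_visibility default to 0, so only boxes that fall entirely outside the image are removed.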
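The filtering idea can be sketched without the library: clip each box to the image, then compare the clipped area with min_area and the ratio of clipped area to original area with min_visibility. The function `filter_bboxes_coco` below is a hypothetical helper written for this explanation; whether albumentations uses strict or non-strict comparisons at the thresholds is not confirmed here and may vary by version.

```python
def filter_bboxes_coco(bboxes, image_size, min_area=0.0, min_visibility=0.0):
    """Clip [x_min, y_min, w, h] boxes to the image and drop the weak ones."""
    img_w, img_h = image_size
    kept = []
    for x, y, w, h in bboxes:
        # Clip the box to the image rectangle.
        x1, y1 = max(0, x), max(0, y)
        x2, y2 = min(img_w, x + w), min(img_h, y + h)
        clipped_w, clipped_h = max(0, x2 - x1), max(0, y2 - y1)
        clipped_area = clipped_w * clipped_h
        original_area = w * h
        if clipped_area <= 0 or original_area <= 0:
            continue  # entirely outside the image (or degenerate)
        if clipped_area < min_area:
            continue  # too small after clipping
        if clipped_area / original_area < min_visibility:
            continue  # too little of the original box is still visible
        kept.append([x1, y1, clipped_w, clipped_h])
    return kept


# Hypothetical boxes relative to an assumed 224x224 image.
boxes = [
    [-50, 10, 100, 100],   # half visible: 50x100 remains
    [300, 10, 50, 50],     # completely outside
    [200, 200, 100, 100],  # only a 24x24 corner remains
]
default_kept = filter_bboxes_coco(boxes, (224, 224))
area_kept = filter_bboxes_coco(boxes, (224, 224), min_area=4000)
visible_kept = filter_bboxes_coco(boxes, (224, 224), min_visibility=0.3)
print(default_kept, area_kept, visible_kept)
```

With the defaults only the box that lies fully outside is dropped; raising either threshold also removes the small corner fragment.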
CenterCrop:
aug = get_aug([CenterCrop(p=1, height=300, width=300)])
augmented = aug(**annotations)
visualize(augmented, category_id_to_name)
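A centre crop moves the coordinate origin to the top-left corner of the crop window, so every box is shifted by the crop offsets. The sketch below shows that shift in plain Python; the helper names, the 640x480 image size, and the sample box are assumptions, and the integer-division offset is the usual convention rather than something this article confirms for the library.

```python
def center_crop_offsets(image_size, crop_size):
    """Top-left corner of a centred crop window; sizes are (width, height)."""
    img_w, img_h = image_size
    crop_w, crop_h = crop_size
    return (img_w - crop_w) // 2, (img_h - crop_h) // 2


def shift_bbox_coco(bbox, offsets):
    """Express a [x_min, y_min, w, h] box in the crop's coordinate system."""
    x, y, w, h = bbox
    dx, dy = offsets
    return [x - dx, y - dy, w, h]


offsets = center_crop_offsets((640, 480), (300, 300))    # assumed sizes
shifted = shift_bbox_coco([100, 50, 200, 120], offsets)  # hypothetical box
print(offsets, shifted)
```

Negative coordinates in the result mean the box sticks out of the crop window; that overhang is what gets clipped, and it is where min_area and min_visibility decide whether the remainder is kept.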
CenterCrop with default filter:
aug = get_aug([CenterCrop(p=1, height=224, width=224)])
augmented = aug(**annotations)
print(augmented['category_id'])
visualize(augmented, category_id_to_name)
CenterCrop + filter with min_area:
aug = get_aug([CenterCrop(p=1, height=224, width=224)], min_area=4000)
augmented = aug(**annotations)
visualize(augmented, category_id_to_name)