坚持写博客💪,分享自己的在学习、工作中的所得
- 给自己做备忘
- 对知识点记录、总结,加深理解
- 给有需要的人一些帮助,少踩一个坑,多走几步路
尽量以合适的方式排版,图文兼有
如果写的有误,或者有不理解的,均可在评论区留言
如果内容对你有帮助,欢迎点赞 👍 收藏 ⭐留言 📝。
虽然平台并不会有任何奖励,但是我会很开心,可以让我保持写博客的热情🙏🙏🙏
深度学习中,通常需要大量的数据,但是很多时候并不能得到足够的数据量,这时就可以使用数据增强来对数据集进行扩充。
常见的增强的库有:
- Augmentor: https://github.com/mdbloice/Augmentor
- imgaug: https://github.com/aleju/imgaug
- albumentations: https://github.com/albumentations-team/albumentations
albumentations
是其中功能较多,速度相比其他较快的一个库。对比
🧁albumentations 自定义增强操作
在使用albumentations做数据增强时,有时某些操作可能不能满足特定需求,这时就需要自定义操作了。
主要思路是先实现操作的方法,即要达到的功能,然后实现自定义操作类去继承albumentations特定的类,即可像albumentations自带方法一样使用。
在albumentations中实现了针对各种不同深度学习任务的增强操作,比如分类、分割、检测、关键点。这里使用分割为例,其他也类似。
🧁自带操作的使用
首先先了解一下自带操作的使用,以OpticalDistortion
操作为例再引入自定义操作
import albumentations as A
from albumentations import DualTransform
import numpy as np
from PIL import Image
import random
import cv2
def read_img_pillow(path):
with open(path, "rb") as f:
img = Image.open(f)
img.convert("RGB")
return np.array(img)
transform = A.OpticalDistortion(distort_limit=0.5, shift_limit=0.1, border_mode=0, p=1)
image = read_img_pillow(r"input.png")
transformed = transform(image=image)
transformed_image = transformed["image"]
print(image.shape, transformed_image.shape)
# (181, 180, 4) (181, 180, 4)
Image.fromarray(transformed_image).show()
input.png
output.png
🧁查看源码
上面是一个模拟镜头光学畸变的操作,代码在transforms.py
class OpticalDistortion(DualTransform):
def __init__(
self,
distort_limit=0.05,
shift_limit=0.05,
interpolation=cv2.INTER_LINEAR,
border_mode=cv2.BORDER_REFLECT_101,
value=None,
mask_value=None,
always_apply=False,
p=0.5,
):
super(OpticalDistortion, self).__init__(always_apply, p)
self.shift_limit = to_tuple(shift_limit)
self.distort_limit = to_tuple(distort_limit)
self.interpolation = interpolation
self.border_mode = border_mode
self.value = value
self.mask_value = mask_value
def apply(self, img, k=0, dx=0, dy=0, interpolation=cv2.INTER_LINEAR, **params):
return F.optical_distortion(img, k, dx, dy, interpolation, self.border_mode, self.value)
def apply_to_mask(self, img, k=0, dx=0, dy=0, **params):
return F.optical_distortion(img, k, dx, dy, cv2.INTER_NEAREST, self.border_mode, self.mask_value)
def get_params(self):
return {
"k": random.uniform(self.distort_limit[0], self.distort_limit[1]),
"dx": round(random.uniform(self.shift_limit[0], self.shift_limit[1])),
"dy": round(random.uniform(self.shift_limit[0], self.shift_limit[1])),
}
def get_transform_init_args_names(self):
return (
"distort_limit",
"shift_limit",
"interpolation",
"border_mode",
"value",
"mask_value",
)
可以发现其具体实现是F.optical_distortion()
functional.py
def optical_distortion( img, k=0, dx=0, dy=0, interpolation=cv2.INTER_LINEAR, border_mode=cv2.BORDER_REFLECT_101, value=None,):
height, width = img.shape[:2]
fx = width
fy = height
cx = width * 0.5 + dx
cy = height * 0.5 + dy
camera_matrix = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]], dtype=np.float32)
distortion = np.array([k, k, 0, 0, 0], dtype=np.float32)
map1, map2 = cv2.initUndistortRectifyMap(camera_matrix, distortion, None, None, (width, height), cv2.CV_32FC1)
img = cv2.remap(img, map1, map2, interpolation=interpolation, borderMode=border_mode, borderValue=value,)
return img
这样一分析,我们只需要实现属于我们自己的class OpticalDistortion
,并实现对应功能的optical_distortion()
方法即可。
🧁实现自定义增强操作
我这里实现了一个特定效果的透视变换
🧁实现操作
def random_persp(img, direction=''): # 302 µs
# 随机透视变换
"""
img: RGB图像
direction: 方向
"""
rows, cols, _ = img.shape
# 原图的4个角点
pts1 = np.float32([[0, 0], [cols, 0], [cols, rows], [0, rows]])
if not direction:
direction = np.random.choice(['left', 'right'])
if direction == 'left':
n = np.random.randint(5, 40)
top_left = [n, 0]
bottom_left = [n, rows]
top_right = [cols-n, min(n, 50)]
bottom_right = [cols-n, rows-n//2]
else:
n = np.random.randint(5, 40)
top_left = [n, min(n, 50)]
bottom_left = [n, rows-n//2]
top_right = [cols-n, 0]
bottom_right = [cols-n, rows]
# pts1变换之后的新坐标
pts2 = np.float32([top_left, top_right, bottom_right, bottom_left])
# 生成变换矩阵
M = cv2.getPerspectiveTransform(pts1, pts2)
# 进行透视变换
dst = cv2.warpPerspective(src=img, M=M, dsize=(cols, rows)) # dsize=(w,h)
return dst
讲一下这个方法实现的功能:即一个物体,你从左右两侧去看,得到的不同视图。
left_im = random_persp(img=image, direction='left')
Image.fromarray(left_im).show()
right_im = random_persp(img=image, direction='right')
Image.fromarray(right_im).show()
🧁实现类
实现自定义操作类并继承DualTransform
类。
class CustomRandomPersp(DualTransform):
def __init__(self, direction: str = '', always_apply: bool = False, p: float = 0.5):
super(CustomRandomPersp, self).__init__(always_apply, p)
self.direction = direction
def apply(self, img, **params) -> np.ndarray:
return random_persp(img, direction=self.direction)
想原生类一样使用
trans = CustomRandomPersp(direction='left', p=1)
transformed = trans(image=image)
transformed_image = transformed["image"]
Image.fromarray(transformed_image).show()
其他增强操作的类有:ImageOnlyTransform
等。
🧁更多实现
class DualTransform(BasicTransform):
"""Transform for segmentation task."""
@property
def targets(self) -> Dict[str, Callable]:
return {
"image": self.apply,
"mask": self.apply_to_mask,
"masks": self.apply_to_masks,
"bboxes": self.apply_to_bboxes,
"keypoints": self.apply_to_keypoints,
}
def apply_to_bbox(self, bbox, **params):
raise NotImplementedError("Method apply_to_bbox is not implemented in class " + self.__class__.__name__)
def apply_to_keypoint(self, keypoint, **params):
raise NotImplementedError("Method apply_to_keypoint is not implemented in class " + self.__class__.__name__)
def apply_to_bboxes(self, bboxes, **params):
return [self.apply_to_bbox(tuple(bbox[:4]), **params) + tuple(bbox[4:]) for bbox in bboxes]
def apply_to_keypoints(self, keypoints, **params):
return [self.apply_to_keypoint(tuple(keypoint[:4]), **params) + tuple(keypoint[4:]) for keypoint in keypoints]
def apply_to_mask(self, img, **params):
return self.apply(img, **{k: cv2.INTER_NEAREST if k == "interpolation" else v for k, v in params.items()})
def apply_to_masks(self, masks, **params):
return [self.apply_to_mask(mask, **params) for mask in masks]
上面的例子中,我们只使用了image
,即transformed["image"]
,如果还需要更多的功能,如mask
,那就需要实现apply_to_mask
。masks
、bboxes
、keypoints
也同样。
还有,如果在前面有注意到,我传入的图片是带有alpha
通道的,所以mask
的需求,这个例子是能够满足的。
Image.fromarray(transformed_image).getchannel(3).save(r'mask.png')
如果内容对你有帮助,或者觉得写的不错
🏳️🌈欢迎点赞 👍 收藏 ⭐留言 📝
有问题,请在评论区留言