Data Augmentation Series (6): Keypoint Augmentation with Albumentations

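In this notebook, we show how to apply Albumentations to the keypoint augmentation problem. You can use any pixel-level augmentation on an image with keypoints, because pixel-level augmentations change only pixel values and leave keypoint coordinates untouched.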

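Note: by default, augmentations that work with keypoints do not change the keypoints' labels after the transformation. If the labels are side-specific, this can cause problems. For example, if you have a keypoint named left arm and apply a HorizontalFlip augmentation, you get a keypoint that still carries the left arm label but now looks like a right arm keypoint.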

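If you work with this kind of keypoint, consider using the SymmetricKeypoints augmentations from albumentations-experimental, experimental augmentations created precisely to handle this case.

pip install -U albumentations_experimental

from albumentations_experimental import FlipSymmetricKeypoints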
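To make the label problem concrete, here is a minimal sketch in plain Python that does not use Albumentations at all. The helper `hflip_with_label_swap` and the `SWAP` table are hypothetical names introduced only for illustration, and the `width - 1 - x` pixel convention is an assumption that may differ from what a given Albumentations version does internally.

```python
# Hypothetical left/right pairs; adapt to your own label set.
SWAP = {"left arm": "right arm", "right arm": "left arm"}


def hflip_with_label_swap(keypoints, labels, width):
    """Mirror (x, y) keypoints horizontally and swap paired labels.

    Assumes integer pixel coordinates, so x maps to width - 1 - x.
    """
    flipped = [(width - 1 - x, y) for x, y in keypoints]
    # Labels without a partner in SWAP are kept unchanged.
    swapped = [SWAP.get(label, label) for label in labels]
    return flipped, swapped


points, names = hflip_with_label_swap(
    [(10, 20), (90, 20)], ["left arm", "right arm"], width=100
)
print(points)  # [(89, 20), (9, 20)]
print(names)   # ['right arm', 'left arm']
```

Without the label swap, the point that ends up on the right side of the mirrored image would still be called left arm, which is exactly the mismatch described in the note.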

1. Import the required packages

import random
import cv2
from matplotlib import pyplot as plt
import albumentations as A

KEYPOINT_COLOR = (0, 255, 0)  # Green

2. Define a function to visualize keypoints on an image

def vis_keypoints(image, keypoints, color=KEYPOINT_COLOR, diameter=15):
    image = image.copy()

    for (x, y) in keypoints:
        # Use the color argument instead of a hard-coded green
        cv2.circle(image, (int(x), int(y)), diameter, color, -1)

    plt.figure(figsize=(8, 8))
    plt.axis('off')
    plt.imshow(image)

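3. Get an image and its annotations

We use the xy format for keypoint coordinates. Each keypoint is defined by two coordinates: x is the position on the x-axis, and y is the position on the y-axis.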

image = cv2.imread('keypoints_image.jpg')
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
keypoints = [
    (100, 100),
    (720, 410),
    (1100, 400),
    (1700, 30),
    (300, 650),
    (1570, 590),
    (560, 800),
    (1300, 750),
    (900, 1000),
    (910, 780),
    (670, 670),
    (830, 670),
    (1000, 670),
    (1150, 670),
    (820, 900),
    (1000, 900),
]
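Since xy keypoints are plain pixel positions, it can help to check that each one actually falls inside the image before augmenting. The following is a small sketch under stated assumptions: `keypoints_outside` is a hypothetical helper, and the image size in the example is made up rather than read from keypoints_image.jpg.

```python
def keypoints_outside(keypoints, width, height):
    """Return the keypoints whose (x, y) falls outside a width x height image."""
    return [
        (x, y)
        for x, y in keypoints
        if not (0 <= x < width and 0 <= y < height)
    ]


# The image size here is invented for the example.
sample = [(100, 100), (1700, 30), (900, 1000)]
print(keypoints_outside(sample, width=1600, height=1200))  # [(1700, 30)]
```

With a real image loaded through OpenCV, the size would come from `height, width = image.shape[:2]`.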

4. Visualize the original image with keypoints

vis_keypoints(image, keypoints)

5. Define a simple augmentation pipeline

transform = A.Compose(
    [A.HorizontalFlip(p=1)],
    keypoint_params=A.KeypointParams(format='xy')
)
transformed = transform(image=image, keypoints=keypoints)
vis_keypoints(transformed['image'], transformed['keypoints'])

6. More examples of augmentation pipelines

transform = A.Compose(
    [A.VerticalFlip(p=1)],
    keypoint_params=A.KeypointParams(format='xy')
)
transformed = transform(image=image, keypoints=keypoints)
vis_keypoints(transformed['image'], transformed['keypoints'])

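# We fix the random seed for visualization purposes, so the augmentation always produces the same result.
# In a real computer vision pipeline you should not fix the random seed before applying a transform to an image,
# because the pipeline would then always output the same image. The purpose of image augmentation is to apply a different transformation each time.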
random.seed(7)
transform = A.Compose(
    [A.RandomCrop(width=768, height=768, p=1)],
    keypoint_params=A.KeypointParams(format='xy')
)
transformed = transform(image=image, keypoints=keypoints)
vis_keypoints(transformed['image'], transformed['keypoints'])
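A crop changes keypoint coordinates and can also return fewer keypoints than went in, because points that land outside the crop window are no longer on the image. The sketch below shows that arithmetic in plain Python; `crop_keypoints` is a hypothetical helper, not an Albumentations function, and the crop origin is an arbitrary choice. In Albumentations, `A.KeypointParams` has a `remove_invisible` argument that controls whether such points are dropped; check the documentation of your installed version for its exact behavior.

```python
def crop_keypoints(keypoints, x_min, y_min, width, height):
    """Shift keypoints into crop coordinates and drop those outside the crop."""
    result = []
    for x, y in keypoints:
        new_x, new_y = x - x_min, y - y_min
        if 0 <= new_x < width and 0 <= new_y < height:
            result.append((new_x, new_y))
    return result


# Crop origin (500, 200) is an arbitrary choice for the example.
cropped = crop_keypoints(
    [(100, 100), (720, 410), (1100, 400), (1700, 30)],
    x_min=500, y_min=200, width=768, height=768,
)
print(cropped)  # [(220, 210), (600, 200)]
```

Two of the four input points fall outside the 768 x 768 window and are discarded, so any per-keypoint labels would need to be filtered the same way.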

random.seed(7)
transform = A.Compose(
    [A.Rotate(p=0.5)],
    keypoint_params=A.KeypointParams(format='xy')
)
transformed = transform(image=image, keypoints=keypoints)
vis_keypoints(transformed['image'], transformed['keypoints'])

transform = A.Compose(
    [A.CenterCrop(height=512, width=512, p=1)],
    keypoint_params=A.KeypointParams(format='xy')
)
transformed = transform(image=image, keypoints=keypoints)
vis_keypoints(transformed['image'], transformed['keypoints'])

random.seed(7)
transform = A.Compose(
    [A.ShiftScaleRotate(p=0.5)],
    keypoint_params=A.KeypointParams(format='xy')
)
transformed = transform(image=image, keypoints=keypoints)
vis_keypoints(transformed['image'], transformed['keypoints'])

7. An example of a complex augmentation pipeline

random.seed(7)
transform = A.Compose([
    A.RandomSizedCrop(min_max_height=(256, 1025), height=512, width=512, p=0.5),
    A.HorizontalFlip(p=0.5),
    A.OneOf([
        A.HueSaturationValue(p=0.5),
        A.RGBShift(p=0.7)
    ], p=1),
    A.RandomBrightnessContrast(p=0.5)
],
    keypoint_params=A.KeypointParams(format='xy'),
)
transformed = transform(image=image, keypoints=keypoints)
vis_keypoints(transformed['image'], transformed['keypoints'])

8. BONUS: Data augmentation with Keras

import numpy as np
import imageio
import os
import matplotlib.pyplot as plt
import pandas as pd
import albumentations as A
import cv2
import json
import random  # used by CustomSeq.on_epoch_end
# Public import path; tensorflow.python.* is a private module tree
from tensorflow.keras.utils import Sequence

def extract_coordinates(df):
    full_coordinates = df['region_shape_attributes']
    ls_coordinates = []
    for coordinates in full_coordinates:
        coordinates = json.loads(coordinates)
        ls_coordinates.append([coordinates['cx'], coordinates['cy']])
    return np.array(ls_coordinates, dtype=np.float32)

def rescale_image(image):
    return (image / np.max(image) * 255.).astype(np.float32)

class CustomSeq(Sequence):
    def __init__(self, path2imgs, df, batch_size, augmentations=None, mode='train'):
        self.path2imgs = path2imgs
        self.df = df
        self.img_list = self.df['filename']
        self.y = extract_coordinates(self.df)
        self.batch_size = batch_size
        self.augmentations = augmentations
        self.mode = mode.lower()
        
    def __len__(self):
        return int(np.ceil(len(self.df) / float(self.batch_size)))
    
    def on_epoch_end(self):
        self.indexes = range(len(self.img_list))
        if self.mode == 'train':
            self.indexes = random.sample(self.indexes, k=len(self.indexes))
    
    def get_batch_labels(self, idx, shapes):
        y_batch = self.y[idx * self.batch_size: (idx+1) * self.batch_size]
        return y_batch
    
    def get_batch_images(self, idx):
        x_batch = []
        shapes = []
        img_names = self.img_list[idx * self.batch_size: (idx+1) * self.batch_size]
        for img_name in img_names:
            image = imageio.imread(os.path.join(self.path2imgs, img_name))
            image = rescale_image(image)
            x_batch.append(image)
            shapes.append(image.shape)
        return x_batch, np.array(shapes)
    
    def __getitem__(self, idx):
        x_batch, shapes = self.get_batch_images(idx)
        y_batch = self.get_batch_labels(idx, shapes)
        if self.augmentations:
            # walk around images and keypoints
            for i, (x_item, y_item) in enumerate(zip(x_batch, y_batch)):
                transformed = self.augmentations(image=x_item, keypoints=np.expand_dims(y_item, axis=0))
                # Rewrite image and keypoints values in not augmented batch
                x_batch[i], y_batch[i] = transformed['image'], transformed['keypoints'][0]
        return x_batch, y_batch


transform = A.Compose([
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    #A.InvertImg(p=0.5),
    #A.ShiftScaleRotate(shift_limit=0, scale_limit=0.1, rotate_limit=15, p=0.5),
    A.ToFloat(max_value=255.)
], keypoint_params=A.KeypointParams(format='xy'))

path2png_imgs = os.getcwd()
df = pd.read_csv('vgg_annotate_crop.csv', header=0)

data = CustomSeq(path2png_imgs, df, 1, augmentations=transform)

images, point = data.__getitem__(0)

# np.int was removed in NumPy 1.24, so convert with the built-in int instead
center = tuple(int(v) for v in point[0])
images[0] = cv2.circle(images[0], center, 30, (1), -1)
plt.figure(figsize=(10, 10))
plt.imshow(images[0], cmap='gray')
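One detail of `CustomSeq` worth noting: `on_epoch_end` builds a shuffled `self.indexes`, but `get_batch_images` and `get_batch_labels` slice `self.img_list` and `self.y` directly, so the shuffle has no effect on what a batch contains. One possible way to wire the shuffled order in is to slice the index list first and look items up through it. The sketch below shows only that step in plain Python; `batch_positions` is a hypothetical helper, not part of the original class.

```python
import random


def batch_positions(indexes, idx, batch_size):
    """Return the dataset positions that make up batch number idx."""
    return list(indexes[idx * batch_size:(idx + 1) * batch_size])


random.seed(0)  # fixed only so the example is repeatable
indexes = random.sample(range(10), k=10)  # shuffled once per epoch

batches = [batch_positions(indexes, i, batch_size=4) for i in range(3)]
print([len(b) for b in batches])  # [4, 4, 2]
```

Inside the class, the returned positions would then select rows, for example `self.y[positions]` for the labels and `self.img_list.iloc[positions]` for the filenames. `self.indexes` would also need to be initialized in `__init__`, since Keras calls `on_epoch_end` only after an epoch finishes.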

References

https://github.com/albumentations-team/albumentations_examples/blob/master/notebooks/example_keypoints.ipynb
