数据增强——image和mask同样变换 的(Pytorch)实现方法

Applying the same augmentation with the same parameters to multiple images and masks.

在做深度学习任务时数据增强是必须的,很多时候我们希望对图片和对应的mask做相同的变换,比如语义分割任务中。我实验了两种实现方式。

1. Albumentations库

Albumentations是一个第三方库,提供了一个单一的界面来处理不同的计算机视觉任务,例如分类、语义分割、实例分割、对象检测、姿态估计等。使用它可以很轻易的实现我们的目的:

from albumentations.augmentations.transforms import Normalize
from albumentations.pytorch.transforms import ToTensorV2

# INPUT中已存入需要配置的参数
transform = A.Compose([
    A.Resize(height=INPUT.IMG_SIZE[0], width=INPUT.IMG_SIZE[1]),
    A.HorizontalFlip(p=INPUT.PROB),
    A.PadIfNeeded(min_height=INPUT.IMG_SIZE[0]+INPUT.PADDING*2, min_width=INPUT.IMG_SIZE[1]+INPUT.PADDING*2, border_mode=cv2.BORDER_CONSTANT, value=0, mask_value=0),
    A.RandomCrop(height=INPUT.IMG_SIZE[0], width=INPUT.IMG_SIZE[1]),
    Normalize(mean=INPUT.PIXEL_MEAN, std=INPUT.PIXEL_STD),
    ToTensorV2(),
])

2. torchvision实现

一般封装的torchvison transform方法如randomCrop等只能处理单张图片,无法处理image和mask成对出现的情况。因此需要对默认方法进行改写,以下是实现的几种增强方法:

import math
import numpy as np
import random
from PIL import Image

import torch
from torchvision import transforms as T
from torchvision.transforms import functional as F


class Resize(object):
    def __init__(self, size):
        self.size = size

    def __call__(self, image, target=None):
        image = F.resize(image, self.size)
        if target is not None:
            target = F.resize(target, self.size, interpolation=F.InterpolationMode.NEAREST)
        return image, target


class RandomHorizontalFlip(object):
    def __init__(self, flip_prob):
        self.flip_prob = flip_prob

    def __call__(self, image, target=None):
        if random.random() < self.flip_prob:
            image = F.hflip(image)
            if target is not None:
                target = F.hflip(target)
        return image, target

class RandomCrop(object):
    def __init__(self, size):
        self.size = size

    def __call__(self, image, target):
        crop_params = T.RandomCrop.get_params(image, self.size)
        image = F.crop(image, *crop_params)
        if target is not None:
            target = F.crop(target, *crop_params)
        return image, target

class CenterCrop(object):
    def __init__(self, size):
        self.size = size

    def __call__(self, image, target):
        image = F.center_crop(image, self.size)
        if target is not None:
            target = F.center_crop(target, self.size)
        return image, target

class Normalize(object):
    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, image, target):
        image = F.normalize(image, mean=self.mean, std=self.std)
        return image, target

class Pad(object):
    def __init__(self, padding_n, padding_fill_value=0, padding_fill_target_value=0):
        self.padding_n = padding_n
        self.padding_fill_value = padding_fill_value
        self.padding_fill_target_value = padding_fill_target_value

    def __call__(self, image, target):
        image = F.pad(image, self.padding_n, self.padding_fill_value)
        if target is not None:
            target = F.pad(target, self.padding_n, self.padding_fill_target_value)
        return image, target

class ToTensor(object):
    def __call__(self, image, target):
        image = F.to_tensor(image)
        if target is not None:
            target = torch.as_tensor(np.array(target), dtype=torch.int64)
        return image, target

但是只实现增强方法还不够,因为torchvision默认的Compose还不支持两个输入,改写如下:

class Compose(object):
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image, mask=None):
        for t in self.transforms:
            image, mask = t(image, mask)
        return {'image':image, 'mask':mask}


# 使用:
transform = Compose([
    Resize(INPUT.IMG_SIZE),
    RandomHorizontalFlip(flip_prob=INPUT.PROB),
    Pad(INPUT.PADDING, 0, 0),
    RandomCrop(INPUT.IMG_SIZE),
    ToTensor(),
    Normalize(mean=INPUT.PIXEL_MEAN, std=INPUT.PIXEL_STD)
])

然后替换torchvison的默认方法就可以了。对了,相应的Dataloader类也要根据增强的输出形式处理。

3. 两种方式结果对比

首先,这两种方法都可以达到我们的目的。

其次,Albumentations库是基于OpenCV实现的,而torchvison是基于PIL实现的,这会导致两种方法的处理结果可能会不同,比如resize。

第三,我做的person reid任务,两种方法都能达到目的。但是基于Albumentations的方法结果会比torchvison方法低两个点,查了半天不知道是什么原因。

  • 22
    点赞
  • 53
    收藏
    觉得还不错? 一键收藏
  • 7
    评论
评论 7
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值