语义分割数据增强——图像和标注同步增强

文章介绍了一个使用PyTorch进行语义分割数据增强的实现,针对细胞图像数据集,考虑到灰度图像的特点,代码中包含了尺寸调整、旋转、翻转等同步增强操作,并排除了不适用于灰度图像的颜色调整方法。提供的代码示例展示了如何在数据预处理阶段应用这些增强策略。
摘要由CSDN通过智能技术生成

语义分割数据增强

常见的数据增强方式

查看pytorchtorchvision的transformer中的源代码,我们可以看到具有以下数据增强方式:

__all__ = ["Compose", "ToTensor", "PILToTensor", "ConvertImageDtype", "ToPILImage", "Normalize", "Resize", "Scale",
           "CenterCrop", "Pad", "Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop",
           "RandomHorizontalFlip", "RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop",
           "LinearTransformation", "ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale",
           "RandomPerspective", "RandomErasing", "GaussianBlur", "InterpolationMode", "RandomInvert", "RandomPosterize",
           "RandomSolarize", "RandomAdjustSharpness", "RandomAutocontrast", "RandomEqualize"]

其中常见的数据增强方式包括:旋转、垂直翻转、水平翻转、放缩、剪裁、归一化等。

语义分割和图像分类的数据增强差异在于:语义分割是对图像的每个像素进行分类,所以在进行某些数据增强时,需要对标注图像(mask)进行同步操作,如旋转、剪裁、翻转等。

查遍网上的一些教程,但是没有发现一个能够直接使用的pytorch数据增强方式,所以想自己写一个方便后续使用。

具体实现代码

我们这里以一个细胞语义分割数据集为例,由于该数据集是灰度图像,所以相对于彩色图像数据增强有一些差距,代码中注释了灰度图像不能使用的数据增强方式,但是彩色图像可以使用的数据增强方式。具体代码如下所示:

import numpy as np
import cv2
import torch
from torch.utils.data import Dataset
import os
from PIL import Image
from torchvision.transforms import functional as F
import random


class CellDataset(Dataset):
    def __init__(self, image_dir, mask_dir, names_list, image_size=224, isGray=False, augmentation=True):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.augmentation = augmentation
        self.names_list = names_list
        self.isGray = isGray
        self.image_size = image_size

    def __len__(self):
        return len(self.names_list)

    def augmentate(self, image, mask):
        # it is expected to be in [..., H, W] format
        image = torch.unsqueeze(torch.from_numpy(
            np.array(image, dtype=np.uint8)), dim=0)
        mask = torch.unsqueeze(torch.from_numpy(
            np.array(mask, dtype=np.uint8)), dim=0)

        image = F.resize(image, size=[self.image_size, self.image_size])
        mask = F.resize(mask, size=[self.image_size, self.image_size])
        # 彩色图可以进行以下数据增强,参数不太好调整
        image = F.adjust_gamma(image, gamma=random.uniform(0.8, 1.2))
        image = F.adjust_contrast(
            image, contrast_factor=random.uniform(0.8, 1.2))
        image = F.adjust_brightness(
            image, brightness_factor=random.uniform(0.8, 1.2))
        image = F.adjust_saturation(
            image, saturation_factor=random.uniform(0.8, 1.2))
        image = F.adjust_hue(image, hue_factor=random.uniform(-0.2, 0.2))

        # 让image和mask进行同步旋转和翻转数据增强
        image_mask = torch.cat([image, mask], dim=0)

        if random.uniform(0, 1) > 0.5:
            image_mask = F.hflip(image_mask)
        if random.uniform(0, 1) > 0.5:
            image_mask = F.vflip(image_mask)
        if random.uniform(0, 1) > 0.5:
            image_mask = F.rotate(image_mask, angle=90)

        # 要看image和mask的维度
        image = image_mask[0, ...]
        mask = image_mask[1, ...]
        # image = image / 255
        # mask = mask / 255

        # image = torch.unsqueeze(image, dim=0)
        # 标准化,彩色图像需要传三个值
        # image = F.normalize(image, mean=[0.5], std=[1])
        # mask = torch.unsqueeze(mask, dim=0)

        return image, mask

    def __getitem__(self, item):
        image_path = os.path.join(self.image_dir, self.names_list[item])
        mask_path = os.path.join(self.mask_dir, self.names_list[item])
        image = Image.open(image_path)
        if self.isGray:
            image = image.convert('L')
        mask = Image.open(mask_path)

        if self.augmentation:
            image, mask = self.augmentate(image, mask)
        return image, mask


if __name__ == '__main__':
    cell_dataset = CellDataset(image_dir='./data/image', mask_dir='./data/label',
                               names_list=['0.png'])
    index = 3
    for image, mask in cell_dataset:
        print(image.shape, mask.shape)
        print(torch.max(image), torch.min(image))
        image = np.array(image, dtype=np.uint8)
        mask = np.array(mask, dtype=np.uint8)
        # cv2.imshow('image', image)
        # cv2.imshow('mask', mask)
        # cv2.waitKey(0)
        cv2.imwrite(os.path.join(
            './data/augment/image', str(index)+'.png'), image)
        cv2.imwrite(os.path.join(
            './data/augment/label', str(index)+'.png'), mask)

我们以下图图像和图像标注掩码为例进行实验:

原始细胞图像

原始细胞掩码标签图像

我们更该index的值,进行重复执行程序,来生成多个不同对应的数据增强图像,我们重复执行了4次,得到了以下数据增强的细胞图像和对应的掩码标签图像。


如果有问题可以在评论区进行回复。如果对您有帮助的话可以帮忙点赞👍👍👍。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Trouble..

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值