语义分割数据增强
常见的数据增强方式
查看 PyTorch 的 torchvision.transforms 源代码,我们可以看到它提供了以下数据增强方式:
# Names exported by torchvision's transforms module, quoted here for reference
# only; this list is not part of the dataset code further below.
__all__ = ["Compose", "ToTensor", "PILToTensor", "ConvertImageDtype", "ToPILImage", "Normalize", "Resize", "Scale",
"CenterCrop", "Pad", "Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop",
"RandomHorizontalFlip", "RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop",
"LinearTransformation", "ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale",
"RandomPerspective", "RandomErasing", "GaussianBlur", "InterpolationMode", "RandomInvert", "RandomPosterize",
"RandomSolarize", "RandomAdjustSharpness", "RandomAutocontrast", "RandomEqualize"]
其中常见的数据增强方式包括:旋转、垂直翻转、水平翻转、缩放、裁剪、归一化等。
语义分割和图像分类的数据增强差异在于:语义分割是对图像的每个像素进行分类,所以在进行旋转、裁剪、翻转等几何类数据增强时,需要对标注图像(mask)进行同步操作。
查遍了网上的教程,也没有找到一个能够直接使用的 PyTorch 语义分割数据增强实现,所以想自己写一个,方便后续使用。
具体实现代码
我们这里以一个细胞语义分割数据集为例。由于该数据集是灰度图像,其数据增强与彩色图像有一些差异:代码中用注释标出了适用于彩色图像、但灰度图像用不上的数据增强方式。具体代码如下所示:
import numpy as np
import cv2
import torch
from torch.utils.data import Dataset
import os
from PIL import Image
from torchvision.transforms import functional as F
import random
class CellDataset(Dataset):
    """Image/mask pair dataset for semantic segmentation with synchronized augmentation.

    Each entry of ``names_list`` is a file name that is looked up in both
    ``image_dir`` and ``mask_dir``; the image and its mask share that name.

    Args:
        image_dir: Directory containing the input images.
        mask_dir: Directory containing the label masks.
        names_list: File names, one per sample.
        image_size: Side length of the square the image and mask are resized
            to during augmentation.
        isGray: When true, the image is converted to single-channel ('L')
            after loading.
        augmentation: When true, ``__getitem__`` returns augmented uint8
            tensors. When false, it returns the loaded PIL images untouched.
    """

    def __init__(self, image_dir, mask_dir, names_list, image_size=224, isGray=False, augmentation=True):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.augmentation = augmentation
        self.names_list = names_list
        self.isGray = isGray
        self.image_size = image_size

    def __len__(self):
        """Return the number of samples (one per entry in ``names_list``)."""
        return len(self.names_list)

    @staticmethod
    def _to_chw_tensor(picture):
        """Convert an image-like object to a uint8 tensor laid out as [C, H, W]."""
        tensor = torch.from_numpy(np.array(picture, dtype=np.uint8))
        if tensor.ndim == 2:
            # Single-channel [H, W]: add the leading channel axis.
            return tensor.unsqueeze(0)
        if tensor.ndim == 3:
            # Multi-channel data arrives as [H, W, C]; the torchvision
            # functional ops used below expect [..., H, W].
            return tensor.permute(2, 0, 1).contiguous()
        raise ValueError(
            'expected a 2-D or 3-D image, got {} dimensions'.format(tensor.ndim))

    def augmentate(self, image, mask):
        """Resize and randomly augment an image together with its mask.

        Photometric changes (gamma, contrast, brightness, saturation, hue)
        are applied to the image only. Flips and the 90-degree rotation are
        applied to image and mask jointly so that they stay pixel-aligned.

        Args:
            image: PIL image or array-like, [H, W] or [H, W, C].
            mask: PIL image or array-like, [H, W] or [H, W, C].

        Returns:
            Tuple ``(image, mask)`` of uint8 tensors. A single-channel result
            is returned as [H, W]; a multi-channel result as [C, H, W].
            Values are not normalized; scaling is left to the caller.
        """
        image = self._to_chw_tensor(image)
        mask = self._to_chw_tensor(mask)

        target_size = [self.image_size, self.image_size]
        image = F.resize(image, size=target_size)
        # Nearest-neighbour keeps every mask pixel inside the original label
        # set; the default bilinear mode would blend neighbouring labels into
        # values that correspond to no class.
        mask = F.resize(mask, size=target_size,
                        interpolation=F.InterpolationMode.NEAREST)

        # Photometric augmentation, image only. These are aimed at colour
        # images and the ranges are hard to tune.
        image = F.adjust_gamma(image, gamma=random.uniform(0.8, 1.2))
        image = F.adjust_contrast(
            image, contrast_factor=random.uniform(0.8, 1.2))
        image = F.adjust_brightness(
            image, brightness_factor=random.uniform(0.8, 1.2))
        image = F.adjust_saturation(
            image, saturation_factor=random.uniform(0.8, 1.2))
        image = F.adjust_hue(image, hue_factor=random.uniform(-0.2, 0.2))

        # Stack along the channel axis so one geometric op transforms the
        # image and the mask identically.
        image_channels = image.shape[0]
        image_mask = torch.cat([image, mask], dim=0)
        if random.uniform(0, 1) > 0.5:
            image_mask = F.hflip(image_mask)
        if random.uniform(0, 1) > 0.5:
            image_mask = F.vflip(image_mask)
        if random.uniform(0, 1) > 0.5:
            image_mask = F.rotate(image_mask, angle=90)

        # Split by the image's channel count rather than fixed indices, so
        # colour images get all of their channels back.
        image = image_mask[:image_channels]
        mask = image_mask[image_channels:]
        # Single-channel results drop the channel axis, matching the [H, W]
        # shape callers of the grayscale path already receive.
        if image.shape[0] == 1:
            image = image[0]
        if mask.shape[0] == 1:
            mask = mask[0]
        return image, mask

    def __getitem__(self, item):
        """Load sample ``item`` and, if enabled, augment it.

        Raises:
            IndexError: If ``item`` is out of range for ``names_list``. Plain
                ``for`` iteration over the dataset relies on this to stop.
        """
        # Index first so an out-of-range item fails before any file access.
        name = self.names_list[item]
        image = Image.open(os.path.join(self.image_dir, name))
        if self.isGray:
            image = image.convert('L')
        mask = Image.open(os.path.join(self.mask_dir, name))
        if self.augmentation:
            image, mask = self.augmentate(image, mask)
        return image, mask
if __name__ == '__main__':
    # Demo: augment every sample in the dataset once and save the results.
    cell_dataset = CellDataset(image_dir='./data/image', mask_dir='./data/label',
                               names_list=['0.png'])

    image_out_dir = './data/augment/image'
    label_out_dir = './data/augment/label'
    # cv2.imwrite reports a missing folder only through its return value, so
    # create the output folders up front.
    os.makedirs(image_out_dir, exist_ok=True)
    os.makedirs(label_out_dir, exist_ok=True)

    # Number used for the first output file. Change it between runs to keep
    # the results of earlier runs instead of overwriting them.
    index = 3
    # Number the outputs per sample so several samples do not overwrite
    # each other.
    for sample_index, (image, mask) in enumerate(cell_dataset, start=index):
        print(image.shape, mask.shape)
        print(torch.max(image), torch.min(image))
        image = np.array(image, dtype=np.uint8)
        mask = np.array(mask, dtype=np.uint8)

        image_path = os.path.join(image_out_dir, str(sample_index) + '.png')
        label_path = os.path.join(label_out_dir, str(sample_index) + '.png')
        # Fail loudly: a False return means nothing was written.
        if not cv2.imwrite(image_path, image):
            raise OSError('failed to write ' + image_path)
        if not cv2.imwrite(label_path, mask):
            raise OSError('failed to write ' + label_path)
我们以下图图像和图像标注掩码为例进行实验:
我们更改 index 的值并重复执行程序,就可以生成多组相互对应的增强图像和掩码。我们重复执行了4次,得到了以下数据增强后的细胞图像和对应的掩码标签图像。
如果有问题可以在评论区进行回复。如果对您有帮助的话可以帮忙点赞👍👍👍。