数据增强在视觉任务中算是一种基本都训练技巧,例如对于分类任务而言,我们常常能看到如下形式的代码(pytorch):
from torchvision import transforms
self.img_transform = transforms.Compose([
transforms.RandomRotation(90),
transforms.RandomVerticalFlip(p=0.5),
transforms.RandomHorizontalFlip(p=0.5),
transforms.Resize((256, 256)),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406],
[0.229, 0.224, 0.225])])
而到了分割任务上,由于分割的标签本身也是图片,因此同样需要进行transform。一个直观的做法是将对图片所进行的增强直接给照搬到标签上来,比如:
from torchvision import transforms
self.img_transform = transforms.Compose([
transforms.RandomRotation(90),
transforms.RandomVerticalFlip(p=0.5),
transforms.RandomHorizontalFlip(p=0.5),
transforms.Resize((256, 256)),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406],
[0.229, 0.224, 0.225])])
self.gt_transform = transforms.Compose([
transforms.RandomRotation(90),
transforms.RandomVerticalFlip(p=0.5),
transforms.RandomHorizontalFlip(p=0.5),
transforms.Resize((self.trainsize, self.trainsize)),
transforms.ToTensor()
])
但是,这种写法实际上是错误的。原因在于,许多数据增强本身是概率性的,例如transforms.RandomVerticalFlip(p=0.5)
表示以50%的概率对输入图像进行垂直翻转。而对图像和标签应用两个独立的概率,就可能会出现图像翻转了而标签没翻转的不匹配情况。
正确的解法对各种概率变换进行单独的封装,将概率给抽离出来,例如:
import torchvision.transforms.functional as F
import random
class RandomVerticalFlip(object):
def __init__(self, p=0.5):
self.p = p
def __call__(self, data):
image, label = data['image'], data['label']
if random.random() < self.p:
return {'image': F.vflip(image), 'label': F.vflip(label)}
return {'image': image, 'label': label}
data = {'image': image, 'label': label}
data = RandomVerticalFlip()(data)