数据增强(Data Augmentation)是一种在训练集上生成新样本的方法,计算机视觉(CV)和自然语言处理 (NLP) 模型中经常使用数据增强,旨在通过创造变体来增加训练数据的多样性,这些变体保留了原始数据的主要信息,但呈现出不同的表达形式。
数据增强的优势:
-
提升模型泛化能力:数据增强通过生成额外的训练样本,使模型能够在更多变体下进行学习,这有助于模型在未见过的数据上表现得更好。
-
减少过拟合:通过增加训练数据量,模型有更少的机会对特定训练实例产生过度依赖,从而降低了过拟合的风险。
-
成本效益高:相比于收集更多的真实数据,数据增强是一种相对便宜且有效的方法来扩充数据集。
-
模拟真实场景:在计算机视觉中,数据增强可以模拟不同的视角、光照条件、遮挡等,使得模型更加适应实际环境中的变化。
-
加速模型收敛:在某些情况下,数据增强可以改善模型的收敛速度,因为模型能够接触到更全面的数据分布。
CV的数据增强方法:
1.几何变换:
- 翻转:水平翻转(左右镜像)或垂直翻转图像。
- 旋转:以任意角度旋转图像。
- 平移:在水平或垂直方向上移动图像。
- 缩放:放大或缩小图像尺寸。
- 倾斜:沿某个轴倾斜图像。
- 透视变换:模拟从不同视角拍摄图像的效果。
2.像素级操作:
- 亮度调整:增加或减少图像的整体亮度。
- 对比度调整:改变图像中颜色或灰度的对比强度。
- 饱和度调整:改变图像的颜色饱和度。
- 色调调整:修改图像的色调。
- 高斯噪声:在图像上添加随机噪声点。
- 椒盐噪声:在图像上添加随机的黑点或白点。
- 模糊:使用高斯模糊或其他滤波器使图像变得模糊。
- 锐化:增强图像边缘,使图像看起来更清晰。
- 色彩抖动:随机改变图像的色彩属性,如亮度、对比度、饱和度和色调。
3.裁剪和填充:
- 随机裁剪:从图像中随机选取一个区域作为新图像。
- 中心裁剪:从图像中心裁剪固定大小的区域。
- 弹性变形:模拟图像在物理上的扭曲效果。
- 随机擦除:在图像上随机选取一个区域并用平均颜色或零填充。
- 填充:当需要保持图像大小一致时,对裁剪后的图像进行填充。
4.混合和组合:
- CutMix:将两个图像按随机形状和位置混合。
- MixUp:线性插值两个图像及其标签。
- GridMask:在图像上应用网格状的遮罩,去除部分信息。
- Mosaic:将多个图像拼接在一起形成一个新的图像。
数据增强库
有一些供开发人员使用的库,例如 Albumentations、Augmentor、Imgaug、nlpaug、NLTK 和 spaCy。这些库包括几何变换和色彩空间变换函数、内核过滤器(即用于锐化和模糊的图像处理函数)和其他文本变换。数据增强库使用不同的深度学习框架,例如 Keras、MxNet、PyTorch 和 TensorFlow。
Torchvision
torchvision
是PyTorch框架的一个子库,主要专注于计算机视觉任务。它提供了许多用于处理图像数据的功能,包括数据加载、预处理、数据增强、常用计算机视觉模型的实现以及一些标准的数据集。
torchvision
的主要组件包括:
-
Transforms:用于图像数据的预处理和数据增强。这包括转换图像大小、裁剪、翻转、颜色调整等操作。
-
Datasets:包含了多种常用的数据集,如CIFAR10、ImageNet、MNIST、COCO等,这些数据集可以直接通过
torchvision.datasets
下载和加载。 -
Models:实现了许多流行的卷积神经网络架构,如ResNet、VGG、AlexNet、Inception、MobileNet等。这些模型既可以用于直接预测,也可以作为预训练模型来微调或进行迁移学习。
-
Utils:提供了一些实用函数,如将Tensor转换为图像、可视化模型输出、保存和加载模型等。
torchvision
通常与torch
库结合使用,torch
提供了深度学习所需的张量计算和自动求导功能。使用torchvision
可以极大地简化构建和训练计算机视觉模型的过程。
torchvision用于数据增强的示例
import torch from torchvision import transforms # 定义数据增强转换 transform = transforms.Compose([ transforms.RandomResizedCrop(224), # 随机裁剪并调整到224x224 transforms.RandomHorizontalFlip(), # 随机水平翻转 transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1), # 颜色抖动 transforms.ToTensor(), # 转换为Tensor transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # 归一化 ]) # 加载数据集 dataset = torchvision.datasets.ImageFolder(root='path/to/dataset', transform=transform)
定义了一个包含多个转换操作的
Compose
对象。RandomResizedCrop
和RandomHorizontalFlip
用于增加图像的空间多样性,ColorJitter
用于增加颜色空间的多样性,而ToTensor
和Normalize
则用于将图像转换为适合模型输入的格式。
注意!对于测试数据集,通常不应用数据增强,只进行必要的预处理,例如转换为Tensor和标准化。
Albumentation
albumentations 是一个用于图像增强的开源 Python 库,专门针对计算机视觉任务设计。它提供了快速和灵活的图像变换方法,被广泛应用于图像分类、目标检测、语义分割和其他基于深度学习的视觉任务中。albumentations 与其他图像增强相关软件包的区别在于,该软件包已通过多个基于 OpenCV 的库进行了优化。
用Albumentations替换上述torchvision的实现代码:
import albumentations as A from albumentations.pytorch import ToTensorV2 import torch from torch.utils.data import DataLoader from torchvision.datasets import ImageFolder # 定义数据增强转换 transform = A.Compose([ A.RandomResizedCrop(height=224, width=224), A.HorizontalFlip(), A.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1), A.Normalize( mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], ), ToTensorV2(), ]) # 定义一个函数来适应 ImageFolder 的需求,因为 ImageFolder 需要返回一个 PIL 图像 def albumentations_transform(image): return transform(image=image)['image'] # 加载数据集 dataset = ImageFolder(root='path/to/dataset', transform=albumentations_transform) # 创建数据加载器 dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
在
Albumentations
中,数据增强是作为一个字典返回的,其中键 'image' 对应于转换后的图像。因此,我们需要定义一个辅助函数albumentations_transform
来适配ImageFolder
数据集的接口,该接口期望一个简单的转换函数作为输入。此外,
Albumentations
提供了ToTensorV2
变换,它会将图像转换为 PyTorch 的张量,并将其通道顺序从 HWC 转换为 CHW。
Baseline代码
train_loader = torch.utils.data.DataLoader( FFDIDataset(train_label['path'].head(1000), train_label['target'].head(1000), transforms.Compose([ transforms.Resize((256, 256)), transforms.RandomHorizontalFlip(), transforms.RandomVerticalFlip(), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ]) ), batch_size=40, shuffle=True, num_workers=4, pin_memory=True ) val_loader = torch.utils.data.DataLoader( FFDIDataset(val_label['path'].head(1000), val_label['target'].head(1000), transforms.Compose([ transforms.Resize((256, 256)), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ]) ), batch_size=40, shuffle=False, num_workers=4, pin_memory=True )
图像大小调整:使用
transforms.Resize((256, 256))
将所有图像调整到256x256像素的尺寸,这有助于确保输入数据的一致性。随机水平翻转:
transforms.RandomHorizontalFlip()
随机地水平翻转图像,这种变换可以模拟物体在不同方向上的观察,从而增强模型的泛化能力。随机垂直翻转:
transforms.RandomVerticalFlip()
随机地垂直翻转图像,这同样是为了增加数据多样性,让模型能够学习到不同视角下的特征。转换为张量:
transforms.ToTensor()
将图像数据转换为PyTorch的Tensor格式,这是在深度学习中处理图像数据的常用格式。归一化:
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
对图像进行归一化处理,这里的均值和标准差是根据ImageNet数据集计算得出的,用于将图像像素值标准化,这有助于模型的训练稳定性和收敛速度。resources:https://datawhaler.feishu.cn/wiki/Ad0jwNK8Eis5XwksFZ7cCvb6nHh
用Albumentations替换Baseline中的torchvision的实现代码:
import albumentations as A
from albumentations.pytorch import ToTensorV2
import torch
from torch.utils.data import DataLoader
# 假设 FFDIDataset 是一个自定义数据集类,这里需要定义一个辅助函数来应用 Albumentations 转换
def albumentations_transform(image, label):
transformed = transform(image=image)
return transformed['image'], label
# 定义训练集的增强
transform_train = A.Compose([
A.Resize(256, 256),
A.HorizontalFlip(),
A.VerticalFlip(),
A.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
),
ToTensorV2()
])
# 定义验证集的转换
transform_val = A.Compose([
A.Resize(256, 256),
A.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
),
ToTensorV2()
])
# 创建数据集实例,这里假设 FFDIDataset 是你已经定义好的数据集类
train_dataset = FFDIDataset(train_label['path'].head(1000), train_label['target'].head(1000))
val_dataset = FFDIDataset(val_label['path'].head(1000), val_label['target'].head(1000))
# 应用转换函数
train_dataset.transform = lambda x, y: albumentations_transform(x, y)
val_dataset.transform = lambda x, y: albumentations_transform(x, y)
# 创建数据加载器
train_loader = DataLoader(
train_dataset, batch_size=40, shuffle=True, num_workers=4, pin_memory=True
)
val_loader = DataLoader(
val_dataset, batch_size=40, shuffle=False, num_workers=4, pin_memory=True
)
Resources:[九月]Deepfake-FFDI-图像赛题 ch3 Modified by Hong (kaggle.com)