数据增强
目的:增加训练数据的多样性,从而提高模型的泛化能力,使其能够在未见过的数据上表现得更好
具体来说,变换原始数据,生成新的训练样本,模拟真实世界中的变化。对于图像而言,数据增强包括例如视角、光照、遮挡等情况,使得模型能够学习到更加鲁棒的(robust)特征表示。
举例:PyTorch, Mixup, Cutmix
注意:不可随意过度变换,使其与目标场景不符,否则引入噪音,图像湿疹,模型难以有效学习
PyTorch框架
train_loader = torch.utils.data.DataLoader(
FFDIDataset(train_label['path'].head(1000), train_label['target'].head(1000),
transforms.Compose([
transforms.Resize((256, 256)),
transforms.RandomHorizontalFlip(),
transforms.RandomVerticalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
), batch_size=40, shuffle=True, num_workers=4, pin_memory=True
)
使用 PyTorch 创建一个数据加载器 (`DataLoader`) 来加载图像数据集并进行一些预处理和数据增强操作。
1. 数据集定义
FFDIDataset(…
])
)
假设 `FFDIDataset` 是一个自定义数据集类,继承自 `torch.utils.data.Dataset`。它的构造函数接受图像路径和标签,以及一系
`train_label['path']`:一个包含所有训练图像文件路径的列表。
`train_label['target']`:一个包含所有训练图像对应标签的列表。
`transforms.Compose`用于将一系列图像变换操作组合在一起。
`transforms.RandomHorizontalFlip()`:以 50% 的概率随机水平翻转图像。
`transforms.RandomVerticalFlip()`:以 50% 的概率随机垂直翻转图像。
`transforms.ToTensor()`:将图像转换为 PyTorch 的张量(tensor)格式,并且将像素值从 [0, 255] 缩放到 [0.0, 1.0]。
`transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])`:使用给定的均值和标准差对图像进行归一化,这通常是用预训练模型(如 ImageNet 上训练的模型)时的标准操作。
红色、绿色、蓝色通道的均值分别为 0.485,0.456,0.406
红色、绿色、蓝色通道的标准值分别为 0.229, 0.224, 0.225
这些标准差也是从 ImageNet 数据集上计算得到的。
先将图像转换为张量并缩放到 [0.0, 1.0],进一步基于每个通道的均值和标准差进行归一化,以确保输入数据符合预训练模型(如在ImageNet上训练的模型)的期望数据分布。
2. 数据加载器定义
train_loader = torch.utils.data.DataLoader(
dataset=FFDIDataset(train_label['path'], train_label['target'], transforms.Compose([…
])
),
…
)
`torch.utils.data.DataLoader`
`DataLoader`是PyTorch 中用于批量加载数据的工具。它接受以下参数:
`dataset`:一个继承自 `torch.utils.data.Dataset` 的数据集对象,在这里是 `FFDIDataset`。
`batch_size=40`:每个批次加载 40 个样本(图像)。
`shuffle=True`:在每个 epoch 开始时对数据进行随机打乱。
`num_workers=4`:使用 4 个子进程来加载数据,可以加快数据加载速度。
`pin_memory=True`:如果设置为 True,数据加载器会在将数据转移到 GPU 之前将数据保存在固定内存中,这样可以加快数据传输速度。
Torchvision
1. 几何变换
Resize调整大小
RandomCrop/RandomResizedCrop随机裁剪
CenterCrop中心裁剪
FiveCrop/TenCrop裁剪四个角/中心区
RandomHorizontalFlip/RandomVerticalFlip水平/垂直翻转
RandomRotation随机旋转
RandomAffine随机仿射变换。
RandomPerspective随机透视变换
举例:
import torchvision.transforms as transforms
from PIL import Image
image = Image.open('image.jpg')
resize_transform = transforms.Resize((256, 256))
resized_image = resize_transform(image)
random_crop_transform = transforms.RandomCrop(size=(200, 200))
cropped_image = random_crop_transform(image)
2. 颜色变换
ColorJitter随机改变图像的亮度、对比度、饱和度和色调
Grayscale/RandomGrayscale灰度化
GaussianBlur高斯模糊
RandomInvert随机反转颜色
RandomPosterize颜色 posterize减少每个颜色通道的位数
RandomSolarize颜色 solarize反转高于阈值的像素值
举例:
import torchvision.transforms as transforms
from PIL import Image
image = Image.open('image.jpg')
color_jitter_transform = transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.3)
jittered_image = color_jitter_transform(image)
gray_transform = transforms.Grayscale(num_output_channels=1)
gray_image = gray_transform(image)
gaussian_blur_transform = transforms.GaussianBlur(kernel_size=(5, 5), sigma=(0.1, 2.0))
blurred_image = gaussian_blur_transform(image)
random_invert_transform = transforms.RandomInvert(p=1.0)
inverted_image = random_invert_transform(image)
posterize_transform = transforms.RandomPosterize(bits=4) #每个颜色通道减少到16个颜色
posterized_image = posterize_transform(image)
solarize_transform = transforms.RandomSolarize(threshold=128)
solarized_image = solarize_transform(image)
3. 自动增强
AutoAugment自动学习数据增强策略
RandAugment可以随机应用一系列数据增强操作
TrivialAugmentWide提供与数据集无关的数据增强AugMix
AugMix通过混合多个增强操作进行数据增强
举例:
import torchvision.transforms as transforms
from PIL import Image
image = Image.open('image.jpg')
auto_augment_transform = transforms.AutoAugment()
augmented_image = auto_augment_transform(image)
rand_aug_transform = transforms.RandAugment(num_ops=2, magnitude=9)
# num_ops: 随机操作次数 magnitude: 强度
augmented_image = rand_aug_transform(image)
trivial_augment_transform = transforms.TrivialAugmentWide()
augmented_image = trivial_augment_transform(image)
augmix_transform = transforms.AugMix()
augment_image = augmix_transform(image)