Deepwhale AI夏令营 数据增强

数据增强

目的:增加训练数据的多样性,从而提高模型的泛化能力,使其能够在未见过的数据上表现得更好

具体来说,变换原始数据,生成新的训练样本,模拟真实世界中的变化。对于图像而言,数据增强包括例如视角、光照、遮挡等情况,使得模型能够学习到更加鲁棒的(robust)特征表示。

举例:PyTorch, Mixup, Cutmix

注意:不可随意过度变换,使其与目标场景不符,否则引入噪音,图像湿疹,模型难以有效学习

PyTorch框架

train_loader = torch.utils.data.DataLoader(
    FFDIDataset(train_label['path'].head(1000), train_label['target'].head(1000), 
            transforms.Compose([
                        transforms.Resize((256, 256)),
                        transforms.RandomHorizontalFlip(),
                        transforms.RandomVerticalFlip(),
                        transforms.ToTensor(),
                        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
    ), batch_size=40, shuffle=True, num_workers=4, pin_memory=True
)

使用 PyTorch 创建一个数据加载器 (`DataLoader`) 来加载图像数据集并进行一些预处理和数据增强操作。

1. 数据集定义

FFDIDataset(…

        ])

)

假设 `FFDIDataset` 是一个自定义数据集类,继承自 `torch.utils.data.Dataset`。它的构造函数接受图像路径和标签,以及一系

`train_label['path']`:一个包含所有训练图像文件路径列表

`train_label['target']`:一个包含所有训练图像对应标签列表

`transforms.Compose`用于将一系列图像变换操作组合在一起

`transforms.RandomHorizontalFlip()`:以 50% 的概率随机水平翻转图像。

`transforms.RandomVerticalFlip()`:以 50% 的概率随机垂直翻转图像。

`transforms.ToTensor()`:将图像转换为 PyTorch 的张量(tensor)格式,并且将像素值从 [0, 255] 缩放到 [0.0, 1.0]。

`transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])`:使用给定的均值和标准差对图像进行归一化,这通常是用预训练模型(如 ImageNet 上训练的模型)时的标准操作。

红色、绿色、蓝色通道的均值分别为 0.485,0.456,0.406

红色、绿色、蓝色通道的标准值分别为 0.229, 0.224, 0.225

这些标准差也是从 ImageNet 数据集上计算得到的。

先将图像转换为张量缩放到 [0.0, 1.0],进一步基于每个通道均值和标准差进行归一化,以确保输入数据符合预训练模型(如在ImageNet上训练的模型)的期望数据分布。

2. 数据加载器定义

train_loader = torch.utils.data.DataLoader(

    dataset=FFDIDataset(train_label['path'], train_label['target'], transforms.Compose([…

        ])

    ),

    …

)

`torch.utils.data.DataLoader`

`DataLoader`是PyTorch 中用于批量加载数据的工具。它接受以下参数:

`dataset`:一个继承自 `torch.utils.data.Dataset` 的数据集对象,在这里是 `FFDIDataset`。

`batch_size=40`:每个批次加载 40 个样本(图像)

`shuffle=True`:在每个 epoch 开始时对数据进行随机打乱。

`num_workers=4`:使用 4 个子进程来加载数据,可以加快数据加载速度。

`pin_memory=True`:如果设置为 True,数据加载器会在将数据转移到 GPU 之前将数据保存在固定内存中,这样可以加快数据传输速度。

Torchvision

1. 几何变换

Resize调整大小

RandomCrop/RandomResizedCrop随机裁剪

CenterCrop中心裁剪

FiveCrop/TenCrop裁剪四个角/中心区

RandomHorizontalFlip/RandomVerticalFlip水平/垂直翻转

RandomRotation随机旋转

RandomAffine随机仿射变换。

RandomPerspective随机透视变换

举例:

import torchvision.transforms as transforms
from PIL import Image

image = Image.open('image.jpg')
resize_transform = transforms.Resize((256, 256))
resized_image = resize_transform(image)

random_crop_transform = transforms.RandomCrop(size=(200, 200))
cropped_image = random_crop_transform(image)

2. 颜色变换

ColorJitter随机改变图像的亮度、对比度、饱和度和色调

Grayscale/RandomGrayscale灰度化

GaussianBlur高斯模糊

RandomInvert随机反转颜色

RandomPosterize颜色 posterize减少每个颜色通道的位数

RandomSolarize颜色 solarize反转高于阈值的像素值

举例:

import torchvision.transforms as transforms
from PIL import Image

image = Image.open('image.jpg')
color_jitter_transform = transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.3)
jittered_image = color_jitter_transform(image)

gray_transform = transforms.Grayscale(num_output_channels=1)
gray_image = gray_transform(image)

gaussian_blur_transform = transforms.GaussianBlur(kernel_size=(5, 5), sigma=(0.1, 2.0))
blurred_image = gaussian_blur_transform(image)

random_invert_transform = transforms.RandomInvert(p=1.0)
inverted_image = random_invert_transform(image)

posterize_transform = transforms.RandomPosterize(bits=4)  #每个颜色通道减少到16个颜色
posterized_image = posterize_transform(image)

solarize_transform = transforms.RandomSolarize(threshold=128)
solarized_image = solarize_transform(image)

3. 自动增强

AutoAugment自动学习数据增强策略

RandAugment可以随机应用一系列数据增强操作

TrivialAugmentWide提供与数据集无关的数据增强AugMix

AugMix通过混合多个增强操作进行数据增强

举例:

import torchvision.transforms as transforms
from PIL import Image

image = Image.open('image.jpg')
auto_augment_transform = transforms.AutoAugment()
augmented_image = auto_augment_transform(image)

rand_aug_transform = transforms.RandAugment(num_ops=2, magnitude=9)  
# num_ops: 随机操作次数 magnitude: 强度
augmented_image = rand_aug_transform(image)

trivial_augment_transform = transforms.TrivialAugmentWide()
augmented_image = trivial_augment_transform(image)

augmix_transform = transforms.AugMix()
augment_image = augmix_transform(image)

  • 9
    点赞
  • 13
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值