在pytorch中常见的数据预处理方法

PyTorch 的 torchvision.transforms 提供了一系列常见的图像预处理方法,适用于 图像分类、检测、分割 等任务。

1. 基本的图像预处理:

from torchvision import transforms

basic_transform = transforms.Compose([
    transforms.Resize((128, 128)),  # 调整大小为 128x128
    transforms.ToTensor(),          # 转为 Tensor,像素归一化为 [0, 1]
])

​ 效果:

​ 将图像调整为 128x128 像素。

​ 将像素值从 [0, 255] 转换为 [0.0, 1.0]。

2.数据增强:

augment_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),   # 随机水平翻转
    transforms.RandomVerticalFlip(p=0.2),     # 随机垂直翻转
    transforms.RandomRotation(30),            # 随机旋转 ±30°
    transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5),  # 随机亮度、对比度、饱和度
    transforms.ToTensor(),
])

​ 效果:

​ 图像会随机水平或垂直翻转。

​ 图像会随机旋转 -30° 到 +30°。

​ 图像颜色会随机变化,增强鲁棒性。

3. 归一化(标准化):

normalize_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

​ 效果:

​ 将图像通道归一化为:

channel = channel − mean std \text{channel} = \frac{\text{channel} - \text{mean}}{\text{std}} channel=stdchannelmean

​ 对于 RGB 图像:

​ Red 通道:(R - 0.5) / 0.5

​ Green 通道:(G - 0.5) / 0.5

​ Blue 通道:(B - 0.5) / 0.5

4. 灰度化(黑白图像):

gray_transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),  # 转为灰度图(1通道)
    transforms.ToTensor()
])

​ 效果:

​ 将彩色图像转换为单通道灰度图。

5. 随机裁剪(随机区域):

random_crop_transform = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),  # 随机裁剪到 80% - 100% 大小
    transforms.ToTensor()
])

​ 效果:

​ 在原图上随机裁剪一个区域,并调整为 224x224。

6. 中心裁剪(固定区域):

center_crop_transform = transforms.Compose([
    transforms.CenterCrop(224),  # 中心裁剪到 224x224
    transforms.ToTensor()
])

​ 效果:

​ 以中心为基准裁剪图像。

7. 自动数据增强(AutoAugment):

from torchvision.transforms import AutoAugment, AutoAugmentPolicy

autoaugment_transform = transforms.Compose([
    AutoAugment(AutoAugmentPolicy.CIFAR10),  # 针对 CIFAR-10 数据集的自动增强
    transforms.ToTensor()
])

​ 效果:

​ 自动选择一系列增强策略(翻转、旋转、色彩调整等),提高泛化性。

8. 组合多个预处理:

full_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3),
    transforms.RandomGrayscale(p=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

​ 效果:

​ 将图像调整为 256x256,随机增强并标准化。

9. 使用这些预处理在 DataLoader 中:

from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader

train_data = CIFAR10(root='./data', train=True, transform=full_transform, download=True)
train_loader = DataLoader(train_data, batch_size=64, shuffle=True)

for imgs, labels in train_loader:
    print(imgs.shape, labels.shape)
    break
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值