PyTorch 的 torchvision.transforms 提供了一系列常见的图像预处理方法,适用于 图像分类、检测、分割 等任务。
1. 基本的图像预处理:
from torchvision import transforms
basic_transform = transforms.Compose([
transforms.Resize((128, 128)), # 调整大小为 128x128
transforms.ToTensor(), # 转为 Tensor,像素归一化为 [0, 1]
])
效果:
将图像调整为 128x128 像素。
将像素值从 [0, 255] 转换为 [0.0, 1.0]。
2.数据增强:
augment_transform = transforms.Compose([
transforms.RandomHorizontalFlip(p=0.5), # 随机水平翻转
transforms.RandomVerticalFlip(p=0.2), # 随机垂直翻转
transforms.RandomRotation(30), # 随机旋转 ±30°
transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5), # 随机亮度、对比度、饱和度
transforms.ToTensor(),
])
效果:
图像会随机水平或垂直翻转。
图像会随机旋转 -30° 到 +30°。
图像颜色会随机变化,增强鲁棒性。
3. 归一化(标准化):
normalize_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
效果:
将图像通道归一化为:
channel = channel − mean std \text{channel} = \frac{\text{channel} - \text{mean}}{\text{std}} channel=stdchannel−mean
对于 RGB 图像:
Red 通道:(R - 0.5) / 0.5
Green 通道:(G - 0.5) / 0.5
Blue 通道:(B - 0.5) / 0.5
4. 灰度化(黑白图像):
gray_transform = transforms.Compose([
transforms.Grayscale(num_output_channels=1), # 转为灰度图(1通道)
transforms.ToTensor()
])
效果:
将彩色图像转换为单通道灰度图。
5. 随机裁剪(随机区域):
random_crop_transform = transforms.Compose([
transforms.RandomResizedCrop(224, scale=(0.8, 1.0)), # 随机裁剪到 80% - 100% 大小
transforms.ToTensor()
])
效果:
在原图上随机裁剪一个区域,并调整为 224x224。
6. 中心裁剪(固定区域):
center_crop_transform = transforms.Compose([
transforms.CenterCrop(224), # 中心裁剪到 224x224
transforms.ToTensor()
])
效果:
以中心为基准裁剪图像。
7. 自动数据增强(AutoAugment):
from torchvision.transforms import AutoAugment, AutoAugmentPolicy
autoaugment_transform = transforms.Compose([
AutoAugment(AutoAugmentPolicy.CIFAR10), # 针对 CIFAR-10 数据集的自动增强
transforms.ToTensor()
])
效果:
自动选择一系列增强策略(翻转、旋转、色彩调整等),提高泛化性。
8. 组合多个预处理:
full_transform = transforms.Compose([
transforms.Resize((256, 256)),
transforms.RandomHorizontalFlip(p=0.5),
transforms.RandomRotation(15),
transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3),
transforms.RandomGrayscale(p=0.1),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
效果:
将图像调整为 256x256,随机增强并标准化。
9. 使用这些预处理在 DataLoader 中:
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader
train_data = CIFAR10(root='./data', train=True, transform=full_transform, download=True)
train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
for imgs, labels in train_loader:
print(imgs.shape, labels.shape)
break