目录
在深度学习中,数据增强(Data Augmentation)是一种提高模型泛化能力的技术,通过从现有数据集中生成新的训练样本来增加数据多样性。数据增强尤其对于图像分类、目标检测和语义分割等计算机视觉任务非常有用,因为这些任务通常需要大量的标注数据。本文将以pytorch中的torchvision.transforms为例介绍常见的数据增强技术。
1.ToTensor:
transforms.ToTensor()
# 将 PIL Image 或 Numpy 数组转换为 torch.FloatTensor,并将数值范围从 [0, 255] 缩放到 [0.0, 1.0]
2.Resize:
transforms.Resize([80,80])
# 对图像进行缩放
# 参数:size(一个整数或元组,指定新的尺寸)
3.RandomRotation:
transforms.RandomRotation(45)
# 随机旋转,-45到45度之间随机
4.CenterCrop:
transforms.CenterCrop(64)
# 将图像中心裁剪到指定尺寸
# 参数:size(裁剪后的尺寸)
5.RandomHorizontalFlip:
transforms.RandomHorizontalFlip(p=0.5)
# 随机水平翻转图像
# 参数:p(翻转的概率)
6.RandomVerticalFlip:
transforms.RandomVerticalFlip(p=0.5)
# 随机垂直翻转图像
# 参数:p(翻转的概率)
7.ColorJitter:
transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1)
# 随机改变图像的亮度、对比度、饱和度和色调
# 参数1为亮度,参数2为对比度,参数3为饱和度,参数4为变化的范围
8.RandomGrayscale:
transforms.RandomGrayscale(p=0.1)
# 随机将图像转换为灰度图
# 参数:p(转换为灰度的概率)
9.Normalize:
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
# 标准化图像的像素值,通常用于训练过程中
# 参数:mean(各通道的均值),std(各通道的标准差)
10.Compose:
transforms.Compose
# 组合多个变换
# 参数:transforms(变换列表)
11.代码示例
import torchvision.transforms as transforms
# 定义数据增强操作
transform = transforms.Compose([
transforms.RandomHorizontalFlip(p=0.5), # 50%概率水平翻转
transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2), # 随机颜色变换
transforms.ToTensor(), # 转换为Tensor
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), # 标准化
])
# 应用数据增强
# img是一个PIL Image
img_transformed = transform(img)