对图片进行数据增强(基于pytorch)

背景

在进行机器学习的任务中,我们的训练数据往往是有限的,在有限的数据集上获得较好的模型训练结果,我们不仅要在模型结构上下功夫,另一方面也需要对数据集进行数据增强


图片数据增强

图像数据增强是一种在训练机器学习和深度学习模型时常用的策略,尤其是在计算机视觉领域。具体而言,它通过创建和原始图像稍有不同的新图像来扩大训练集。 数据增强的主要目标有以下几点:

  •  解决过拟合:过拟合是指模型在训练集上表现得过于优秀,但是在测试集(即未见过的新数据)上表现差的现象。一个常见的解决过拟合的策略是增加训练数据。数据增强通过在原有训练数据的基础上增加各种变化的数据,有效地增大了训练集。
  • 提高模型的泛化能力:一些数据增强手段(如旋转、缩放、平移等)可以模拟一些真实场景中会产生的视觉变化,有助于训练模型对这些场景变化更具有鲁棒性,从而提高模型的泛化能力。
  • 引入可控制的噪声:一些数据增强方法,如随机裁剪、像素值噪声、颜色偏移等,可以在一定程度上模拟真实环境中的噪声。以这样的方式引入的噪声可以使模型更健壮,并且增强模型的噪声容忍力。
  • 视觉不变性:通过像翻转、旋转这样的变换,数据增强可以帮助模型在任何视觉角度下都能正确地识别出相同的对象,输入图像进行各种方式的扭曲后仍能被模型准确识别出来,增强了模型的视觉不变性。 总的来说,图片数据增强可以让模型学习到更多样性的数据,可以在一定程度上提升模型的识别准确率,更好的适应实际环境中样本的多样性,从而提高模型的泛化能力。

代码实现

我们使用torchvision的transforms库对图片数据进行数据增强,使用一张卡比巴拉的图片

首先读取图片数据,以下是准备工作

from PIL import Image
import numpy as np
import torchvision.transforms as tfs
import matplotlib.pyplot as plt

img_path = r"D:\CSDN_point\1_4\kabibala.jpg"
img = Image.open(img_path)
print("the shape of img is {}".format(np.array(img).shape))

图片伸缩

img_re = tfs.Resize((500,1000))(img)
plt.imshow(img_re)
plt.show()

tfs.Reszie((500,1000))把图像的高和宽分别拉伸到500像素和1000像素

图片裁剪

img_crop = tfs.RandomCrop(500)(img)
plt.imshow(img_crop)
plt.show()

tfs.RandomCrop(500)随机截取图片500\times500大小的区域

中心裁剪

img_crop_cen = tfs.CenterCrop(700)(img)
plt.imshow(img_crop_cen)
plt.show()

tfs.CenterCrop(700)裁剪图片中心位置700\times700大小的区域

随机水平翻转

# 随机水平翻转,概率是0.5
img_hori = tfs.RandomHorizontalFlip()(img)
# 随机垂直翻转,概率是0.5
img_ver = tfs.RandomVerticalFlip()(img)

plt.subplot(1,2,1)
plt.imshow(img_hori)
plt.title("RandomHorizontalFlip")
plt.subplot(1,2,2)
plt.imshow(img_ver)
plt.title("RandomVerticalFlip")
plt.show()

随机改变图片亮度、对比度和色相

img_j = tfs.ColorJitter(brightness=0.5,contrast=0.5,hue=0.5)(img)
plt.imshow(img_j)
plt.show()

tfs.ColorJitter(brightness=0.5,contrast=0.5,hue=0.5)中参数的含义:

  1. brightness=0.5:亮度的浮点数系数,brightness系数在 [max(0, 1 - brightness), 1 + brightness] 的范围内随机选择。例如,brightness=0.5,就表明亮度在 [0.5, 1.5] 的范围内随机选择。

  2. contrast=0.5:对比度的浮点数系数。对比度系数在 [max(0, 1 - contrast), 1 + contrast] 的范围内随机选择。例如,contrast=0.5,就表明对比度在 [0.5, 1.5] 的范围内随机选择。

  3. hue=0.5:色相的浮点数系数。色相系数在 [-hue, hue] 的范围内随机选择。例如,hue=0.5,就表明色相在 [-0.5, 0.5] 的范围内随机选择。

多变换整合

可以使用Compose模块将多个变换进行整合

# 将多个变换整合到my_trans
my_trans = tfs.Compose([tfs.Resize(800),tfs.RandomHorizontalFlip(),tfs.RandomCrop(500),tfs.RandomVerticalFlip(),tfs.ColorJitter(brightness=0.5,contrast=0.5,hue=0.5)])

_,fig = plt.subplots(3,3)
for i in range(3):
    for j in range(3):
        fig[i][j].imshow(my_trans(img))
        plt.axis('off')
plt.show()

欢迎大家讨论交流~


  • 53
    点赞
  • 26
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
### 回答1: 以下是一个基于PyTorch数据增强代码示例,使用了torchvision.transforms模块中的一些预定义的转换函数。 ```python import torchvision.transforms as transforms # 定义一个数据增强的转换函数 transform = transforms.Compose([ transforms.RandomHorizontalFlip(), # 随机水平翻转 transforms.RandomRotation(10), # 随机旋转10度以内 transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2), # 随机调整亮度、对比度、饱和度和色调 transforms.RandomCrop(32, padding=4), # 随机裁剪,大小为32,四周填充4个像素 transforms.ToTensor(), # 转换为tensor transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # 标准化 ]) # 使用该转换函数对数据集进行增强 augmented_dataset = dataset.transform(transform) ``` 这个例子展示了一些简单但常用的数据增强方法,但实际上还有很多其他的方法可以使用,具体取决于你的应用场景和需要。另外,还可以将这些转换函数组合成不同的变换序列,以创建更复杂的增强方法。 ### 回答2: import torchvision.transforms as transforms from PIL import Image import torch # 定义自定义的数据增强函数 def data_augmentation(image_path): # 加载原始图像 image = Image.open(image_path) # 数据增强 transform = transforms.Compose([ transforms.RandomHorizontalFlip(), # 随机水平翻转 transforms.RandomVerticalFlip(), # 随机垂直翻转 transforms.RandomRotation(30), # 随机旋转 transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2), # 随机改变亮度、对比度、饱和度和色调 transforms.RandomResizedCrop(size=256, scale=(0.8, 1.0)), # 随机裁剪并调整大小 transforms.ToTensor(), # 转为张量 ]) augmented_image = transform(image) return augmented_image # 示例用法 image_path = 'example.jpg' augmented_image = data_augmentation(image_path) print(augmented_image.shape) # 输出:torch.Size([3, 256, 256]) # 将张量可视化为图像 normalized_image = augmented_image.permute(1, 2, 0).numpy() # 调整维度顺序并转为numpy数组 Image.fromarray((normalized_image * 255).astype('uint8')).show() # 转为PIL图像并显示 ### 回答3: 下面是一个基于PyTorch数据增强的示例代码: ```python import torch import torchvision.transforms as transforms # 创建一个数据集 dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True) # 定义数据增强的变换 transform = transforms.Compose([ transforms.RandomHorizontalFlip(), transforms.RandomVerticalFlip(), transforms.RandomCrop(32, padding=4), transforms.ToTensor(), transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) ]) # 创建一个数据加载器 dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True) # 对数据进行增强并展示 for images, labels in dataloader: augmented_images = transform(images) # 进行其他操作,例如训练模型等 ``` 该代码使用了PyTorch提供的`torchvision.transforms`模块中的一些常用数据增强操作,包括随机水平翻转(`RandomHorizontalFlip`)、随机垂直翻转(`RandomVerticalFlip`)、随机裁剪(`RandomCrop`)、将PIL图像转换为张量(`ToTensor`)和归一化处理(`Normalize`)。可以根据具体需求自行选择和组合这些操作。 在上述示例代码中,我们使用CIFAR-10数据集作为示例。首先,创建了一个数据集对象,然后定义了一个数据增强的变换对象`transform`。接着,创建了一个数据加载器`dataloader`,并通过循环遍历数据加载器,对每个数据进行增强操作并展示。 需要注意的是,上述代码仅仅是一个示例,实际应用中可能还需要进行其他相关处理,例如准备训练集和测试集、设置批次大小和迭代次数等。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值