深度学习之 数据增强 【附代码】

数据增强


前面我们已经讲了几个非常著名的卷积网络的结构,但是单单只靠这些网络并不能取得很好的结果,现实问题往往更加复杂,非常容易出现过拟合的问题,而数据增强的方法是对抗过拟合问题的一个重要方法。

2012 年 AlexNet 在 ImageNet 上大获全胜,图片增强方法功不可没,因为有了图片增强,使得训练的数据集比实际数据集多了很多’新’样本,减少了过拟合的问题,下面我们来具体解释一下。

常用的数据增强方法

常用的数据增强方法如下:

  1. 对图片进行一定比例缩放
  2. 对图片进行随机位置的截取
  3. 对图片进行随机的水平和竖直翻转
  4. 对图片进行随机角度的旋转
  5. 对图片进行亮度、对比度和颜色的随机变化

这些方法 pytorch 都已经为我们内置在了 torchvision 里面,我们在安装 pytorch 的时候也安装了 torchvision,下面我们来依次展示一下这些数据增强方法

import sys
sys.path.append('..')

from PIL import Image
from torchvision import transforms as tfs
# 读入一张图片
im = Image.open('图片/cat.jpg')
im

在这里插入图片描述

sys.path 是一个列表 list ,它里面包含了 已经添加到系统的环境变量 路径。
当我们要添加自己的引用模块搜索目录时,可以通过列表 list 的 append()方法;

对于需要引用的模块和需要执行的脚本文件不在同一个目录时,可以按照如下形式来添加路径:

import sys  
sys.path.append(’需要引用模块的地址')  
# sys.path.append(..)   # 这代表添加当前路径的上一级目录

1.1 随机比例缩放主要使用的是

随机比例缩放主要使用的是 torchvision.transforms.Resize()
这个函数,第一个参数可以是一个整数,那么图片会保存现在的宽和高的比例,并将更短的边缩放到这个整数的大小。

第一个参数也可以是一个 tuple,那么图片会直接把宽和高缩放到这个大小;

第二个参数表示放缩图片使用的方法,比如最邻近法,或者双线性差值等,一般双线性差值能够保留图片更多的信息,所以 pytorch 默认使用的是双线性差值,

# 比例缩放
print('before scale, shape: {}'.format(im.size))
new_im = tfs.Resize((100, 200))(im)
print('after scale, shape: {}'.format(new_im.size))
new_im
before scale, shape: (710, 1000)
after scale, shape: (200, 100)

在这里插入图片描述

1.2随机位置截取

随机位置截取能够提取出图片中局部的信息,使得网络接受的输入具有多尺度的特征,所以能够有较好的效果。在 torchvision中主要有下面两种方式,一个是 torchvision.transforms.RandomCrop(),传入的参数就是截取出的图片的长和宽,对图片在随机位置进行截取;第二个是 torchvision.transforms.CenterCrop(),同样传入截取
的图片的大小作为参数,会在图片的中心进行截取

# 随机裁剪出 500 x 500 的区域
random_im1 = tfs.RandomCrop((500,500))(im)
random_im1

在这里插入图片描述

# 中心裁剪出 300 x 300 的区域
center_im = tfs.CenterCrop(300)(im)
center_im

在这里插入图片描述

1.3 随机的水平和竖直方向翻转

对于上面这一张猫的图片,如果我们将它翻转一下,它仍然是一张猫,但是图片就有了更多的多样性,所以随机翻转也是一种非常有效的手段。在 torchvision 中,随机翻转使用的是 torchvision.transforms.RandomHorizontalFlip()torchvision.transforms.RandomVerticalFlip()

# 随机水平翻转
h_filp = tfs.RandomHorizontalFlip()(im)
h_filp

在这里插入图片描述

# 随机竖直翻转
v_flip = tfs.RandomVerticalFlip()(im)
v_flip

在这里插入图片描述

1.3 随机角度旋转

一些角度的旋转仍然是非常有用的数据增强方式,在 torchvision 中,使用 torchvision.transforms.RandomRotation() 来实现,其中第一个参数就是随机旋转的角度,比如填入 10,那么每次图片就会在 -10 ~ 10 度之间随机旋转

rot_im = tfs.RandomRotation(45)(im)
rot_im

在这里插入图片描述

1.4 亮度、对比度和颜色的变化

除了形状变化外,颜色变化又是另外一种增强方式,其中可以设置亮度变化,对比度变化和颜色变化等,在 torchvision 中主要使用 torchvision.transforms.ColorJitter() 来实现的,第一个参数就是亮度的比例,第二个是对比度,第三个是饱和度,第四个是颜色

# 亮度
bright_im = tfs.ColorJitter(brightness=1)(im) 
bright_im

在这里插入图片描述

# 对比度
contrast_im = tfs.ColorJitter(contrast=10)(im) 
contrast_im

在这里插入图片描述

# 颜色
color_im = tfs.ColorJitter(hue=0.5)(im) # 随机从 -0.5 ~ 0.5 之间对颜色变化
color_im

在这里插入图片描述

上面我们讲了这么图片增强的方法,其实这些方法都不是孤立起来用的,可以联合起来用,比如先做随机翻转,然后随机截取,再做对比度增强等等,torchvision 里面有个非常方便的函数能够将这些变化合起来,就是 torchvision.transforms.Compose(),下面我们举个例子

im_aug = tfs.Compose([
    tfs.Resize(120),
    tfs.RandomHorizontalFlip(),
    tfs.RandomCrop(96),
    tfs.ColorJitter(brightness=0.5, contrast=0.5, hue=0.5)
])
import matplotlib.pyplot as plt
%matplotlib inline
nrows = 3
ncols = 3
figsize = (8, 8)
_, figs = plt.subplots(nrows, ncols, figsize=figsize)# 创建子图画布
for i in range(nrows):
    for j in range(ncols):
        figs[i][j].imshow(im_aug(im))
        figs[i][j].axes.get_xaxis().set_visible(False)
        figs[i][j].axes.get_yaxis().set_visible(False)
plt.show()

在这里插入图片描述
参考:数据增强

  • 5
    点赞
  • 51
    收藏
    觉得还不错? 一键收藏
  • 4
    评论
深度学习数据增强是一种常用的技术,它可以通过对原始数据进行变换、旋转、缩放等操作来生成更多的训练数据,从而提高模型的泛化能力和鲁棒性。 以下是一个基于Keras实现的简单数据增强代码示例: ```python from keras.preprocessing.image import ImageDataGenerator # 创建ImageDataGenerator对象 datagen = ImageDataGenerator( rotation_range=20, # 旋转范围 width_shift_range=0.1, # 水平平移范围 height_shift_range=0.1, # 垂直平移范围 shear_range=0.2, # 错切变换范围 zoom_range=0.2, # 缩放范围 horizontal_flip=True, # 是否进行水平翻转 fill_mode='nearest' # 填充方式 ) # 加载原始数据 train_data = ... train_labels = ... # 对原始数据进行增强 datagen.fit(train_data) augmented_data = datagen.flow(train_data, train_labels, batch_size=batch_size) # 使用增强后的数据进行训练 model.fit(augmented_data, epochs=epochs, steps_per_epoch=steps_per_epoch) ``` 在上面的代码中,我们使用Keras提供的`ImageDataGenerator`类创建了一个数据增强对象,并设置了旋转、平移、错切、缩放、翻转等操作的范围。然后,我们将原始数据传入`datagen.fit()`方法中进行增强,生成增强后的数据集`augmented_data`。最后,我们使用增强后的数据集进行模型训练。 当然,上面的代码只是一个简单的示例,实际应用中可能需要根据具体任务的需求进行相应的参数设置和数据增强操作。
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值