动手学PyTorch(李沐)21 ---- 数据增广

数据增广

数据增强:在已有的数据集上,对数据进行变换,使得具有更大的多样性

  • 在语言里面加入各种不同的噪音
  • 增加不同的颜色和形状

做数据增强一般是随机增强,把许多方法随机作用在数据上:

常见数据增强方法如下:

QA:

  1. num_workers 根据GPU的性能来决定
  2. cifar10如果要做到95精度,差不多需要200个epoch
  3. 训练精度下降,还可以继续训练
  4. mosaic增广方式主要是添加一种遮挡,mixup(多张数据叠加)也是一种数据增强方式
  5. mixup是一种有效但无法解释的增广方法

%matplotlib inline
import torch
import torchvision
from torch import nn
from d2l import torch as d2l

d2l.set_figsize()
img = d2l.Image.open('/content/drive/MyDrive/深度学习/OIP.jpg')
d2l.plt.imshow(img)
def apply(img,aug,num_rows=2,num_cols=4,scale=1.5):
  y = [aug(img) for _ in range(num_rows * num_cols)]
  d2l.show_images(y,num_rows,num_cols,scale=scale)

左右翻转

apply(img,torchvision.transforms.RandomHorizontalFlip())
上下翻转
apply(img,torchvision.transforms.RandomVerticalFlip())

随机裁剪

shape_aug = torchvision.transforms.RandomResizedCrop(
    (200,200),scale=(0.1,1),ratio=(0.5,2)
)
# 最后的输出是(200*200),scale就是指裁取原始图像的比例从0.1-1,ratio指裁取的高宽比
apply(img,shape_aug)

随机改变图像亮度

apply(img,torchvision.transforms.ColorJitter(
    brightness=0.5,contrast=0,saturation=0,hue=0
))
# brightness 亮度上下改变0.5之间,contrast 对比度,saturation 饱和度,hue 色调

随机改变 亮度、对比度、饱和度、色调

apply(img,torchvision.transforms.ColorJitter(
    brightness=0.5,contrast=0.5,saturation=0.5,hue=0.5
))
结合多种图像增广的方法
augs = torchvision.transforms.Compose([
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ColorJitter(
    brightness=0.5,contrast=0.5,saturation=0.5,hue=0.5),
    torchvision.transforms.RandomResizedCrop(
    (200,200),scale=(0.1,1),ratio=(0.5,2))
])
apply(img,augs)

使用图像增广进行训练

all_images = torchvision.datasets.CIFAR10(
    train=True,root="../data",download=True
)
d2l.show_images([
    all_images[i][0] for i in range(32)],4,8,scale=0.8)

只使用简单的随机左右翻转

trian_augs = torchvision.transforms.Compose(
    [
        torchvision.transforms.RandomHorizontalFlip(),
        torchvision.transforms.ToTensor()
        #ToTensor就是说把图片变成一个四维张量,一般做图片增广都这样操作
    ]
)

test_augs = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor()
    ]
)

定义辅助函数,便于读取图像和应用图像增广

def load_cifar10(is_train,augs,batch_size):
  dataset = torchvision.datasets.CIFAR10(
      root = "../data",train=is_train,
      transform=augs,download=True,
  )
  dataloader = torch.utils.data.DataLoader(
      dataset,batch_size = batch_size,shuffle=is_train,
      num_workers=d2l.get_dataloader_workers()
      # 如果做了数据增广,可以把num_workers定义大一点
  )
  return dataloader

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值