数据增广
数据增强:在已有的数据集上,对数据进行变换,使得具有更大的多样性
- 在语言里面加入各种不同的噪音
- 增加不同的颜色和形状
做数据增强一般是随机增强,把许多方法随机作用在数据上:
常见数据增强方法如下:
QA:
- num_workers 根据GPU的性能来决定
- cifar10如果要做到95精度,差不多需要200个epoch
- 训练精度下降,还可以继续训练
- mosaic增广方式主要是添加一种遮挡,mixup(多张数据叠加)也是一种数据增强方式
- mixup是一种有效但无法解释的增广方法
%matplotlib inline
import torch
import torchvision
from torch import nn
from d2l import torch as d2l
d2l.set_figsize()
img = d2l.Image.open('/content/drive/MyDrive/深度学习/OIP.jpg')
d2l.plt.imshow(img)
def apply(img,aug,num_rows=2,num_cols=4,scale=1.5):
y = [aug(img) for _ in range(num_rows * num_cols)]
d2l.show_images(y,num_rows,num_cols,scale=scale)
左右翻转
apply(img,torchvision.transforms.RandomHorizontalFlip())
apply(img,torchvision.transforms.RandomVerticalFlip())
随机裁剪
shape_aug = torchvision.transforms.RandomResizedCrop(
(200,200),scale=(0.1,1),ratio=(0.5,2)
)
# 最后的输出是(200*200),scale就是指裁取原始图像的比例从0.1-1,ratio指裁取的高宽比
apply(img,shape_aug)
随机改变图像亮度
apply(img,torchvision.transforms.ColorJitter(
brightness=0.5,contrast=0,saturation=0,hue=0
))
# brightness 亮度上下改变0.5之间,contrast 对比度,saturation 饱和度,hue 色调
随机改变 亮度、对比度、饱和度、色调
apply(img,torchvision.transforms.ColorJitter(
brightness=0.5,contrast=0.5,saturation=0.5,hue=0.5
))
augs = torchvision.transforms.Compose([
torchvision.transforms.RandomHorizontalFlip(),
torchvision.transforms.ColorJitter(
brightness=0.5,contrast=0.5,saturation=0.5,hue=0.5),
torchvision.transforms.RandomResizedCrop(
(200,200),scale=(0.1,1),ratio=(0.5,2))
])
apply(img,augs)
使用图像增广进行训练
all_images = torchvision.datasets.CIFAR10(
train=True,root="../data",download=True
)
d2l.show_images([
all_images[i][0] for i in range(32)],4,8,scale=0.8)
只使用简单的随机左右翻转
trian_augs = torchvision.transforms.Compose(
[
torchvision.transforms.RandomHorizontalFlip(),
torchvision.transforms.ToTensor()
#ToTensor就是说把图片变成一个四维张量,一般做图片增广都这样操作
]
)
test_augs = torchvision.transforms.Compose(
[
torchvision.transforms.ToTensor()
]
)
定义辅助函数,便于读取图像和应用图像增广
def load_cifar10(is_train,augs,batch_size):
dataset = torchvision.datasets.CIFAR10(
root = "../data",train=is_train,
transform=augs,download=True,
)
dataloader = torch.utils.data.DataLoader(
dataset,batch_size = batch_size,shuffle=is_train,
num_workers=d2l.get_dataloader_workers()
# 如果做了数据增广,可以把num_workers定义大一点
)
return dataloader
…