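1. Data augmentation
- The essence of artificial intelligence:
  - Statistics: analyze the collected samples to infer the properties of the whole population. The more complete and varied the samples, the better the inference.
- Sample collection:
  - Costly and hard to do.
- Data augmentation:
  - Once the samples have been collected, generate synthetic data in software to enrich the diversity of the samples.
  - In essence: add a suitable amount of noise to the samples to simulate samples from different scenarios.
  - Data augmentation is applied only during model training, to increase the diversity of the training samples.
  - Normal inference does not need data augmentation.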
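The idea that augmentation is training-time noise can be sketched in plain Python. This is a minimal illustration, not a library API: the function name `augment`, the `training` flag, and the `noise_std` parameter are all assumptions made up for this example.

```python
import random

def augment(sample, training, noise_std=0.05, rng=None):
    """Return a noisy copy of `sample` while training, an unchanged copy otherwise.

    All names here are hypothetical; real pipelines use image transforms
    rather than plain Gaussian noise on a list of numbers.
    """
    if not training:
        return list(sample)
    rng = rng or random.Random()
    return [x + rng.gauss(0.0, noise_std) for x in sample]

sample = [0.2, 0.5, 0.8]
rng = random.Random(0)

# During training, every pass over the same sample yields a slightly different variant.
noisy_a = augment(sample, training=True, rng=rng)
noisy_b = augment(sample, training=True, rng=rng)

# During inference, the sample passes through untouched.
clean = augment(sample, training=False)
```

The single collected sample turns into many slightly different training samples, while inference always sees the original data.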
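2. Data preprocessing vs. data augmentation
- Preprocessing:
  - Resize: scaling
  - ToTensor: conversion to a tensor
  - Normalize: normalization
- Data augmentation:
  - Random crop
  - Random rotation
  - Random flip
  - Color jitter
  - Image mixing
  - …
- Training: preprocessing + data augmentation
- Inference: preprocessing only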
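3. from torchvision import transforms
- transforms:
  - Purpose: provides ready-made implementations of most common image preprocessing and image augmentation methods
  - Usage: the same as any other layer
    - class --> object
    - call the object as if it were a function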
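The "class --> object --> call it like a function" pattern can be sketched without torchvision. The classes `Scale`, `Shift`, and `Chain` below are hypothetical stand-ins invented for this illustration; `Chain` only mimics the idea behind `transforms.Compose` and is not its implementation.

```python
class Scale:
    """Hypothetical transform: multiply every value by a fixed factor."""

    def __init__(self, factor):
        self.factor = factor

    def __call__(self, values):
        return [v * self.factor for v in values]


class Shift:
    """Hypothetical transform: add a fixed offset to every value."""

    def __init__(self, offset):
        self.offset = offset

    def __call__(self, values):
        return [v + self.offset for v in values]


class Chain:
    """Apply a list of callables in order, in the spirit of transforms.Compose."""

    def __init__(self, steps):
        self.steps = steps

    def __call__(self, values):
        for step in self.steps:
            values = step(values)
        return values


scale = Scale(2)             # class --> object (configuration happens here)
doubled = scale([1, 2, 3])   # the object is used like a function

pipeline = Chain([Scale(2), Shift(-1)])
result = pipeline([1, 2, 3])  # scaled first, then shifted
```

Configuration lives in `__init__` and the work happens in `__call__`, which is why a configured transform can be stored, reused, and chained.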
Code
import torch
from torch import nn
from PIL import Image
from torchvision import transforms
image = Image.open("./Russ.jpg")
# PIL reports the size as (width, height)
image.size
(231, 343)
image
# Resize, size: (h, w)
resize = transforms.Resize(size=(200,400))
resize(image)
# Center crop, size: (h, w)
center_crop = transforms.CenterCrop(size=(200,200))
center_crop(image)
# A crop larger than the image: the image is padded with 0 before cropping
center_crop = transforms.CenterCrop(size=(400,400))
center_crop(image)
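The two center crops above behave differently because the second one is larger than the 231 x 343 image. The arithmetic can be sketched as follows; `center_crop_plan` is a hypothetical helper written for this note and only computes how much padding is needed in total, not how torchvision splits it between the two sides.

```python
def center_crop_plan(image_hw, crop_hw):
    """Return (total_pad_h, total_pad_w, output_hw) for a center crop.

    Padding is needed only along an edge where the crop exceeds the image.
    """
    img_h, img_w = image_hw
    crop_h, crop_w = crop_hw
    pad_h = max(0, crop_h - img_h)
    pad_w = max(0, crop_w - img_w)
    return pad_h, pad_w, (crop_h, crop_w)


image_hw = (343, 231)  # image.size is (231, 343), i.e. (w, h)

small = center_crop_plan(image_hw, (200, 200))  # fits inside the image
large = center_crop_plan(image_hw, (400, 400))  # exceeds the image on both axes
```

The output always has the requested size; for the 400 x 400 crop the missing 57 rows and 169 columns appear as a black border.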
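# Color jitter: random changes to brightness, saturation, and hue (torchvision adjusts hue via the HSV color space)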
color_jitter = transforms.ColorJitter(brightness=0.1,saturation=0.1,hue=0.1)
color_jitter(image)
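For float arguments, `ColorJitter` samples its brightness and saturation factors uniformly from `[max(0, 1 - v), 1 + v]` and its hue shift from `[-v, v]`. The sketch below reproduces only that sampling step in plain Python; `jitter_ranges` is a hypothetical helper, and no image is modified.

```python
import random

def jitter_ranges(brightness, saturation, hue):
    """Sampling intervals implied by float arguments to ColorJitter."""
    return {
        "brightness": (max(0.0, 1.0 - brightness), 1.0 + brightness),
        "saturation": (max(0.0, 1.0 - saturation), 1.0 + saturation),
        "hue": (-hue, hue),
    }


ranges = jitter_ranges(brightness=0.1, saturation=0.1, hue=0.1)

# Each call of the transform draws fresh factors, so the same image looks slightly different every time.
rng = random.Random(0)
factors = {name: rng.uniform(lo, hi) for name, (lo, hi) in ranges.items()}
```

With `0.1` everywhere, brightness and saturation are scaled by a factor between 0.9 and 1.1, which is a mild perturbation.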
# Chain several transforms together
compose = transforms.Compose(transforms=[color_jitter,center_crop])
compose(image)
# Random flips: each one is applied with probability p
random_horizontal_flip = transforms.RandomHorizontalFlip(p=0.5)
random_vertical_flip = transforms.RandomVerticalFlip(p=0.5)
compose = transforms.Compose([random_horizontal_flip,random_vertical_flip])
compose(image)
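What `p=0.5` means can be shown on a tiny grid of numbers standing in for an image. The two functions below are simplified stand-ins written for this note, not torchvision code: each draws one random number and flips only if it falls below `p`.

```python
import random

def random_horizontal_flip(rows, p=0.5, rng=random):
    """Mirror each row left to right with probability p."""
    if rng.random() < p:
        return [row[::-1] for row in rows]
    return rows


def random_vertical_flip(rows, p=0.5, rng=random):
    """Reverse the order of the rows (top to bottom) with probability p."""
    if rng.random() < p:
        return rows[::-1]
    return rows


grid = [[1, 2, 3],
        [4, 5, 6]]

always_h = random_horizontal_flip(grid, p=1.0)  # always flipped
always_v = random_vertical_flip(grid, p=1.0)    # always flipped
never = random_horizontal_flip(grid, p=0.0)     # never flipped
```

Chaining both flips with `p=0.5` therefore yields one of four equally likely outcomes: unchanged, mirrored, upside down, or both.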
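# Random crop, size: (h, w)
transforms.RandomCrop(size=(250,160))(image)
# Random rotation by an angle sampled from [-10°, 10°]
transforms.RandomRotation(degrees=10)(image)
# Convert to a tensor: the layout becomes [c, h, w] and the values are scaled to [0, 1]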
transforms.ToTensor()(image).shape
torch.Size([3, 343, 231])
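The shape `[3, 343, 231]` reflects two changes made by `ToTensor` on an 8-bit RGB image: the layout goes from height x width x channel to channel x height x width, and the values are divided by 255. A plain-Python sketch of the same rearrangement on nested lists (the helper `to_chw` is hypothetical and far slower than the real transform):

```python
def to_chw(pixels_hwc):
    """Rearrange a nested [h][w][c] list of 0..255 ints into [c][h][w] floats in [0, 1]."""
    height = len(pixels_hwc)
    width = len(pixels_hwc[0])
    channels = len(pixels_hwc[0][0])
    return [
        [[pixels_hwc[y][x][c] / 255 for x in range(width)] for y in range(height)]
        for c in range(channels)
    ]


# A tiny RGB image with h=2 and w=3
image_hwc = [
    [[255, 0, 0], [0, 255, 0], [0, 0, 255]],
    [[0, 0, 0], [255, 255, 255], [51, 102, 153]],
]

chw = to_chw(image_hwc)
shape = (len(chw), len(chw[0]), len(chw[0][0]))  # (channels, height, width)
```

Here `shape` is `(3, 2, 3)`, the small-scale analogue of `[3, 343, 231]` for the 231 x 343 photo.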
# Normalize each channel as (x - mean) / std; with mean=0.5 and std=0.5 this maps [0, 1] to [-1, 1]
transforms.Normalize(mean=[0.5,0.5,0.5],std=[0.5,0.5,0.5])(transforms.ToTensor()(image))
tensor([[[ 0.4118, 0.4039, 0.3961, ..., 0.1922, 0.1843, 0.1843],
[ 0.4118, 0.4039, 0.3961, ..., 0.1922, 0.1843, 0.1765],
[ 0.4118, 0.3961, 0.3961, ..., 0.1922, 0.1765, 0.1765],
...,
[ 0.7333, 0.7490, 0.7412, ..., 0.5843, 0.5922, 0.6000],
[ 0.7333, 0.7490, 0.7412, ..., 0.5686, 0.6078, 0.6078],
[ 0.7412, 0.7569, 0.7490, ..., 0.5529, 0.6157, 0.6078]],
[[ 0.3020, 0.2941, 0.2863, ..., 0.0824, 0.0745, 0.0745],
[ 0.3020, 0.2941, 0.2863, ..., 0.0824, 0.0745, 0.0667],
[ 0.3020, 0.2863, 0.2863, ..., 0.0824, 0.0667, 0.0667],
...,
[ 0.6314, 0.6471, 0.6392, ..., 0.4824, 0.4902, 0.4980],
[ 0.6314, 0.6471, 0.6392, ..., 0.4667, 0.5059, 0.5059],
[ 0.6392, 0.6549, 0.6471, ..., 0.4510, 0.5137, 0.5059]],
[[ 0.2000, 0.1922, 0.1843, ..., -0.0196, -0.0275, -0.0275],
[ 0.2000, 0.1922, 0.1843, ..., -0.0196, -0.0275, -0.0353],
[ 0.2000, 0.1843, 0.1843, ..., -0.0196, -0.0353, -0.0353],
...,
[ 0.5608, 0.5765, 0.5686, ..., 0.4118, 0.4196, 0.4275],
[ 0.5608, 0.5765, 0.5686, ..., 0.3961, 0.4353, 0.4353],
[ 0.5686, 0.5843, 0.5765, ..., 0.3804, 0.4431, 0.4353]]])
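The printed values can be checked by hand. `Normalize` computes `(x - mean) / std` per channel, so with mean 0.5 and std 0.5 a pixel of 0 becomes -1 and a pixel of 1 becomes 1. The first red value in the output, 0.4118, is consistent with an 8-bit pixel value of 180. A minimal sketch of that arithmetic, with `normalize` as a hypothetical scalar helper:

```python
def normalize(x, mean=0.5, std=0.5):
    """Scalar version of the per-channel normalization formula."""
    return (x - mean) / std


darkest = normalize(0.0)            # black pixel
brightest = normalize(1.0)          # fully saturated pixel
first_red = normalize(180 / 255)    # 8-bit value 180 after ToTensor scaling
first_red_rounded = round(first_red, 4)
```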
# Training-time data pipeline (preprocessing + data augmentation)
train_trans = transforms.Compose(transforms=[
    # Data augmentation
    transforms.ColorJitter(brightness=0.1, saturation=0.1, hue=0.1),
    transforms.RandomCrop(size=(200, 200)),
    transforms.RandomRotation(degrees=10),
    transforms.RandomHorizontalFlip(p=0.5),
    # Preprocessing
    transforms.Resize(size=(256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
train_trans(image)
tensor([[[-1., -1., -1., ..., -1., -1., -1.],
[-1., -1., -1., ..., -1., -1., -1.],
[-1., -1., -1., ..., -1., -1., -1.],
...,
[-1., -1., -1., ..., -1., -1., -1.],
[-1., -1., -1., ..., -1., -1., -1.],
[-1., -1., -1., ..., -1., -1., -1.]],
[[-1., -1., -1., ..., -1., -1., -1.],
[-1., -1., -1., ..., -1., -1., -1.],
[-1., -1., -1., ..., -1., -1., -1.],
...,
[-1., -1., -1., ..., -1., -1., -1.],
[-1., -1., -1., ..., -1., -1., -1.],
[-1., -1., -1., ..., -1., -1., -1.]],
[[-1., -1., -1., ..., -1., -1., -1.],
[-1., -1., -1., ..., -1., -1., -1.],
[-1., -1., -1., ..., -1., -1., -1.],
...,
[-1., -1., -1., ..., -1., -1., -1.],
[-1., -1., -1., ..., -1., -1., -1.],
[-1., -1., -1., ..., -1., -1., -1.]]])
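The printed output being all -1 does not mean the image was lost. A likely explanation, assuming the default `fill=0` of `RandomRotation`: PyTorch abbreviates large tensors and shows only the first and last few rows and columns, which here are the corners. After a rotation those corners have no source pixel, so they are filled with 0, and 0 normalizes to -1. The geometry can be checked in plain Python; `source_point` and `inside` are hypothetical helpers, and the sketch ignores interpolation and the later resize.

```python
import math

def source_point(x, y, angle_deg, size):
    """Map an output pixel back to its source location for a rotation about the image center."""
    c = (size - 1) / 2
    t = math.radians(angle_deg)
    dx, dy = x - c, y - c
    sx = c + dx * math.cos(t) + dy * math.sin(t)
    sy = c - dx * math.sin(t) + dy * math.cos(t)
    return sx, sy


def inside(point, size):
    """True if the source location lies within the original image."""
    sx, sy = point
    return 0 <= sx <= size - 1 and 0 <= sy <= size - 1


def normalize(x, mean=0.5, std=0.5):
    return (x - mean) / std


size = 200  # the random crop is 200 x 200

# The top-left corner has no source pixel for a 10 degree rotation in either direction.
corner_has_source_pos = inside(source_point(0, 0, 10, size), size)
corner_has_source_neg = inside(source_point(0, 0, -10, size), size)

# The center of the image always has one.
center_has_source = inside(source_point(100, 100, 10, size), size)

# Pixels without a source are filled with 0, which normalizes to -1.
fill_after_normalize = normalize(0.0)
```

Indexing into the interior of the tensor, for example around row 128 and column 128, should show ordinary image values.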
# Inference-time data pipeline (preprocessing only)
infer_trans = transforms.Compose(transforms=[
    # Preprocessing
    transforms.Resize(size=(256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
infer_trans(image)
tensor([[[ 0.4118, 0.4039, 0.3961, ..., 0.1922, 0.1843, 0.1843],
[ 0.4118, 0.3961, 0.3961, ..., 0.1922, 0.1765, 0.1765],
[ 0.4118, 0.3961, 0.3961, ..., 0.1843, 0.1765, 0.1686],
...,
[ 0.7333, 0.7490, 0.7412, ..., 0.5922, 0.5843, 0.6000],
[ 0.7333, 0.7490, 0.7412, ..., 0.5843, 0.6000, 0.6000],
[ 0.7412, 0.7569, 0.7490, ..., 0.5686, 0.6157, 0.6078]],
[[ 0.3020, 0.2941, 0.2863, ..., 0.0824, 0.0745, 0.0745],
[ 0.3020, 0.2863, 0.2863, ..., 0.0824, 0.0667, 0.0667],
[ 0.3020, 0.2863, 0.2863, ..., 0.0745, 0.0667, 0.0588],
...,
[ 0.6314, 0.6471, 0.6392, ..., 0.4902, 0.4824, 0.4980],
[ 0.6314, 0.6471, 0.6392, ..., 0.4824, 0.4980, 0.4980],
[ 0.6392, 0.6549, 0.6471, ..., 0.4667, 0.5137, 0.5059]],
[[ 0.2000, 0.1922, 0.1843, ..., -0.0196, -0.0275, -0.0275],
[ 0.2000, 0.1843, 0.1843, ..., -0.0196, -0.0353, -0.0353],
[ 0.2000, 0.1843, 0.1843, ..., -0.0275, -0.0353, -0.0431],
...,
[ 0.5608, 0.5765, 0.5686, ..., 0.4196, 0.4118, 0.4275],
[ 0.5608, 0.5765, 0.5686, ..., 0.4118, 0.4275, 0.4275],
[ 0.5686, 0.5843, 0.5765, ..., 0.3961, 0.4431, 0.4353]]])