Task02:数据读取与数据扩增
数据读取与数据扩增的目的
数据读取
在采用Pytorch的前提下,需要针对训练集和验证集分别构建DataLoader,这样才能将数据送入模型完成loss计算、梯度计算、反向传播和参数更新。因此数据读取就是为了准备好封装数据传入的容器。
数据扩增
通过对已有的数据库中的图片做一些微小的处理,例如调整图像的颜色、尺寸、像素、空间等等特征以达到丰富数据的目的,这样不仅可以增大训练数据数量,降低训练方差,还能提高模型的泛化能力。
数据读取与数据扩增的方法
数据读取
图片的读取可以通过PIL或者OpenCV两种库进行读取,基本语句如下:
PIL
from PIL import Image
im = Image.open('filepath')
OpenCV
import cv2
im = cv2.imread('filepath')
im = im.cvtColor(im,cv2.COLOR_BGR2RGB)
数据扩增
常用的数据扩增的方法有随机裁剪、对比度亮度饱和度调整、翻转、旋转、像素填充等等,这些操作可以通过简单的操作就能达到丰富数据的目的。
但是本例中需要注意不能对图片实现翻转操作,因为在数字识别的时候,6翻转之后就变成了9,这样会引起歧义。
下面是列举一些常用的图片操作。首先导入相关的库。
from PIL import Image,ImageFilter
import numpy as np
import glob
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
from torchvision.transforms import functional as TF
然后读取一张图片
train_path = glob.glob('../input/train/*.png')
im = Image.open(train_path[0])
- transforms.CenterCrop 对图片中心进行裁剪
crop_obj = transforms.CenterCrop((300, 300))
crop_obj(im)
- transforms.ColorJitter 对图像颜色的对比度、饱和度和零度进行变换
ColorJitter_obj = transforms.ColorJitter(brightness=4, contrast=0.5, saturation=0.1, hue=0.2)
plt.figure(figsize=(10,10))
plt.subplot(121)
plt.imshow(im)
plt.subplot(122)
plt.imshow(ColorJitter_obj(im))
- transforms.FiveCrop 对图像四个角和中心进行裁剪得到五分图像
FiveCrop_obj = transforms.FiveCrop(100)
plt.figure(figsize=(10,10))
for i, img in enumerate(FiveCrop_obj(im)):
plt.subplot(1,5,i+1)
plt.imshow(img)
plt.figure()
plt.imshow(im)
- transforms.Grayscale 对图像进行灰度变换
Grayscale_obj = transforms.Grayscale()
plt.figure(figsize=(10,10))
plt.subplot(121)
plt.imshow(im)
plt.subplot(122)
plt.imshow(Grayscale_obj(im),cmap='gray')
- transforms.RandomAffine 随机仿射变换
RandomAffine_obj = transforms.RandomAffine(20)
plt.figure(figsize=(10,10))
plt.subplot(121)
plt.imshow(im)
plt.subplot(122)
plt.imshow(RandomAffine_obj(im))
- transforms.RandomCrop 随机区域裁剪
RandomCrop_obj = transforms.RandomCrop((100,200))
plt.figure(figsize=(10,10))
plt.subplot(121)
plt.imshow(im)
plt.subplot(122)
plt.imshow(RandomCrop_obj(im))
- transforms.RandomRotation 随机旋转
RandomRotation_obj = transforms.RandomRotation(30)
plt.figure(figsize=(10,10))
plt.subplot(121)
plt.imshow(im)
plt.subplot(122)
plt.imshow(RandomRotation_obj(im))