Datawhale 零基础入门CV赛事-Task2 数据读取与数据扩增
文章目录
1 数据读取与数据扩增
数据增广是深度学习中常用的技巧之一,主要用于增加训练数据集,让数据集尽可能的多样化,使得训练的模型具有更强的泛化能力。目前数据增广主要包括:水平/垂直翻转,旋转,缩放,裁剪,剪切,平移,对比度,色彩抖动,噪声等。传统图像算法中,常用几何变换来进行数据增广,其中常用方法有:缩放,平移,旋转,仿射等。
2 常见的数据扩增方法
本文仅对OpenCV常用方法进行介绍。
在常见的数据扩增方法中,一般会从图像颜色、尺寸、形态、空间和像素等角度进行变换。当然不同的数据扩增方法可以自由进行组合,得到更加丰富的数据扩增方法。
以torchvision为例,常见的数据扩增方法包括:
- transforms.CenterCrop 对图片中心进行裁剪
- transforms.ColorJitter 对图像颜色的对比度、饱和度和零度进行变换
- transforms.FiveCrop 对图像四个角和中心进行裁剪得到五分图像
- transforms.Grayscale 对图像进行灰度变换
- transforms.Pad 使用固定值进行像素填充
- transforms.RandomAffine 随机仿射变换
- transforms.RandomCrop 随机区域裁剪
- transforms.RandomHorizontalFlip 随机水平翻转
- transforms.RandomRotation 随机旋转
- transforms.RandomVerticalFlip 随机垂直翻转
2.1 读入并显示图片
import cv2 # 导入Opencv库
import matplotlib.pyplot as plt
img = cv2.imread('image/cat.jpg') #读入图片
print(img.shape)
# Opencv默认颜色通道顺序是BRG,转换一下
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
plt.figure(figsize=(4,2.25), dpi=300) #figsize设置图片大小,单位英寸,dpi设置分辨率
plt.imshow(img)
2.2 PyTorch数据增强(image transformations)
在介绍具体数据扩增方法前,先对PyTorch数据增强的内容做一个简单了解。
2.2.1 Compose
torchvision.transforms.Compose(transforms)
#用法:
transforms.Compose([
transforms.CenterCrop(10),
transforms.ToTensor(),
])
2.2.2 常见变化
Resize
图像尺寸变化
torchvision.transforms.Resize(size, interpolation=2)
标准化
对数据按通道进行标准化,即先减均值,再除以标准差,注意是 hwc(高度,宽度,通道数)
torchvision.transforms.Normalize(mean, std)
参数:
mean-均值
std-标准差
转换为 PILImage
将 tensor 或者 ndarray 的数据转换为 PIL Image 类型数据
torchvision.transforms.ToPILImage(mode=None)
参数:
mode- 为 None 时,为 1 通道, mode=3 通道默认转换为 RGB, 4 通道默认转换为 RGBA
转为 Tensor
将 PIL Image 或者 ndarray 转换为 tensor,并且归一化至[0-1]
torchvision.transforms.ToTensor
# 转换为 PILImage
img = cv2.imread('image/cat.jpg') #得到的img是一个ndarray类型,默认img为BRG
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) #转化为RGB
img_PIL_obj= torchvision.transforms.ToPILImage() #对象实例化,
img_PIL=img_PIL_obj(img) #将ndarray数据类型转为PIL Image数据类型
#显示图片
plt.figure
plt.imshow(img_PIL)
#print(type(img))
#print(type(img_PIL_obj))
#print(img.shape)
2.2.3 裁剪 Crop
transforms.CenterCrop 中心裁剪
torchvision.transforms.CenterCrop(size) 依据给定的 size 从中心裁剪
import PIL.Image as Image #导入PIL库,使得读入数据为PIL数据类型
image=Image.open("image/cat.jpg") #读入图片RGB格式
crop_obj = transforms.CenterCrop((224,224)) # 对图片中心进行裁剪
image_center = crop_obj(image) # 对图片中心进行裁剪
print(image_center.size, image_center.format, image_center.mode)
#将裁剪之后的图片保存下来
image.save("image/cat.png", format='PNG')
#显示图片
plt.figure
plt.imshow(image_center)
(224, 224) None RGB