点赞再看,养成习惯!
前言
继上一节对数据进行极其简单的数据分析后,这一节开始做数据加载,目标就是组织好数据,可以以一种正确的姿势喂给后续的模型。不同的深度学习框架,数据加载这一块是有所不同的,这里讲解的是PyTorch的数据处理工具。
正文
图像读取
这里主要介绍两个常用的库:
Pillow【轻量级】
Pillow是Python图像处理函式库(PIL)的一个分支。Pillow提供了常见的图像读取和处理的操作,而且可以与ipython notebook无缝集成,是应用比较广泛的库。
from PIL import Image
# 图像读取
im =Image.open(path)
OpenCV【重量级】
OpenCV是一个跨平台的计算机视觉库,最早由Intel开源得来。OpenCV发展的非常早,拥有众多的计算机视觉、数字图像处理和机器视觉等功能。OpenCV在功能上比Pillow更加强大很多,学习成本也高很多。
import cv2
# 图像读取
img = cv2.imread('cat.jpg')
# Opencv默认颜色通道顺序是BGR,转换一下
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
【小编友情提醒】
虽然python程序在使用opencv是导入cv2,但是真正用conda或者pip下载的库的名字叫opencv-python,这点要格外注意!
数据扩增
在深度学习中数据扩增方法非常重要,数据扩增可以增加训练集的样本,同时也可以有效缓解模型过拟合的情况,也可以给模型带来的更强的泛化能力。这里是针对图像数据进行扩增,所以常见的角度有图像颜色、尺寸、形态、空间和像素等。其实小编以前常见常用的也只有图像颜色变化、翻转、裁剪这三种操作。不过这里字符不可以进行翻转,例如6倒过来会变成9,改变字符原先的含义。
常见的库
- torchvision
pytorch官方提供的数据扩增库,提供了基本的数据数据扩增方法,可以无缝与torch进行集成;但数据扩增方法种类较少,且速度中等;
常用方法:
transforms.RandomCrop 随机区域裁剪
transforms.ColorJitter 对图像颜色的对比度、饱和度和零度进行变换
transforms.Grayscale 对图像进行灰度变换
transforms.Pad 使用固定值进行像素填充
transforms.RandomRotation 随机旋转
SVHNDataset(train_path, train_label,
transforms.Compose([
transforms.Resize((64, 128)),
transforms.ColorJitter(0.3, 0.3, 0.2), #颜色变化
transforms.RandomRotation(5), #随机旋转,不能旋转太多
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]))
- imgaug
imgaug是常用的第三方数据扩增库,提供了多样的数据扩增方法,且组合起来非常方便,速度较快; - albumentations
是常用的第三方数据扩增库,提供了多样的数据扩增方法,对图像分类、语义分割、物体检测和关键点检测都支持,速度较快。
图像扩增示例效果图:
PyTorch数据加载
PyTorch数据加载的过程是:数据集本身要转化成Dataset实例,而提供给模型训练、验证或测试时的读取要用DataLoader实例。
- Dataset:对数据集的封装,提供索引方式的对数据样本进行读取
- DataLoader:对Dataset进行封装,提供批量读取的迭代读取,可以用多进程加速
实施流程:
- 继承Dataset类,并实现__init__、getitem、__len__等函数成员,这里类名为SVHNDataset。
class SVHNDataset(Dataset):
def __init__(self, img_path, img_label, transform=None):
self.img_path = img_path #所有图像数据路径
self.img_label = img_label #所有图像标签数据
if transform is not None:
self.transform = transform #预处理流
else:
self.transform = None
def __getitem__(self, index):
# just handle one data
img = Image.open(self.img_path[index]).convert('RGB') #读取图像
if self.transform is not None:
img = self.transform(img) #预处理
# 定长字符识别策略,填充的字符为10,这样不会与有效字符0-9发生碰撞
lbl = np.array(self.img_label[index], dtype=np.int)
lbl = list(lbl) + (5 - len(lbl)) * [10]
return img, torch.from_numpy(np.array(lbl[:5]))
def __len__(self):
return len(self.img_path) #数据集大小
- DataLoader加载SVHNDataset
train_loader = torch.utils.data.DataLoader(
SVHNDataset(train_path, train_label,
transforms.Compose([
transforms.Resize((64, 128)),
transforms.ColorJitter(0.3, 0.3, 0.2),
transforms.RandomRotation(5),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])),
batch_size=10, # 每批样本个数
shuffle=False, # 是否打乱顺序
num_workers=5, #进程个数
)
结语
PyTorch数据加载的流程较为固定,但因为Dataset能够自定义,所以数据读取就比较灵活。值得说一句的是,数据预处理的数据扩增并不是说直接扩增数据,比如把3W的训练集扩增到更多,而是在深度学习的训练过程中把每张图片都通过transform处理流进行变化,这样不同的迭代中同一索引的图像都不一定相同,从而达到了数据扩增的目标。
参考文献
- Pillow的官方文档:https://pillow.readthedocs.io/en/stable/
- OpenCV官网:https://opencv.org/
OpenCV Github:https://github.com/opencv/opencv
OpenCV 扩展算法库:https://github.com/opencv/opencv_contrib - torchvision: https://github.com/pytorch/vision
- imgaug: https://github.com/aleju/imgaug
- albumentations: https://albumentations.readthedocs.io
童鞋们,让小编听见你们的声音,点赞评论,一起加油。