数据读取
导入需要的包以及文件路径
import json, glob
import numpy as np
from PIL import Image
import torch
from torch.utils.data.dataset import Dataset
import torchvision.transforms as transforms
train_path = glob.glob("../mchar_train/*.png")
label_path = "../mchar_train.json"
train_json = json.load(open(label_path))
train_label = [train_json[x]['label'] for x in train_json]
图像读取
1.pillow
Pillow是Python图像处理函式库(PIL)的一个分支。Pillow提供了常见的图像读取和处理的操作,而且可以与ipython notebook无缝集成,是应用比较广泛的库。pillow官方文档
2.opencv
OpenCV是一个跨平台的计算机视觉库,最早由Intel开源得来。OpenCV发展的非常早,拥有众多的计算机视觉、数字图像处理和机器视觉等功能。OpenCV在功能上比Pillow更加强大很多,学习成本也高很多。opencv官网,opencv GitHub,OpenCV 扩展算法库
数据读取
pytorch有Dataset类,自己定义的Dataset都要继承这个类。参考Pytorch中文文档。这个类必须有getitem函数用来索引数据和len函数用来判断大小。transform用来对图像做变换
class SVHNDataset(Dataset):
def __init__(self, img_path, img_label, transform=None):
self.img_path = img_path
self.img_label = img_label
if transform is None:
self.transform = None
else:
self.transform = transform
def __getitem__(self, item):
img = Image.open(self.img_path[item]).convert('RGB')
if self.transform is not None:
img = self.transform(img)
label = self.img_label[item]
label = list(label) + [10]*(5-len(label))
return img, torch.from_numpy(np.array(label[:5]))
def __len__(self):
return len(self.img_path)
数据扩增
在深度学习模型的训练过程中,数据扩增是必不可少的环节。现有深度学习的参数非常多,一般的模型可训练的参数量基本上都是万到百万级别,而训练集样本的数量很难有这么多。
其次数据扩增可以扩展样本空间,假设现在的分类模型需要对汽车进行分类,左边的是汽车A,右边为汽车B。如果不使用任何数据扩增方法,深度学习模型会从汽车车头的角度来进行判别,而不是汽车具体的区别。
在常见的数据扩增方法中,一般会从图像颜色、尺寸、形态、空间和像素等角度进行变换。当然不同的数据扩增方法可以自由进行组合,得到更加丰富的数据扩增方法。
以torchvision为例,常见的数据扩增方法包括:
- transforms.CenterCrop 对图片中心进行裁剪
- transforms.ColorJitter 对图像颜色的对比度、饱和度和零度进行变换
- transforms.FiveCrop 对图像四个角和中心进行裁剪得到五分图像
- transforms.Grayscale 对图像进行灰度变换
- transforms.Pad 使用固定值进行像素填充
- transforms.RandomAffine 随机仿射变换
- transforms.RandomCrop 随机区域裁剪
- transforms.RandomHorizontalFlip 随机水平翻转
- transforms.RandomRotation 随机旋转
- transforms.RandomVerticalFlip 随机垂直翻转
在Pytorch中,数据扩增在Dataloader中实现。Dataloader是数据加载器。组合数据集和采样器,并在数据集上提供单进程或多进程迭代器。
train_loader = torch.utils.data.DataLoader(
SVHNDataset(train_path, train_label,
transforms.Compose([
transforms.Resize((64, 128)),
transforms.ColorJitter(0.3, 0.3, 0.2),
transforms.RandomRotation(degrees=5),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])),
batch_size=10, # 每批样本个数
shuffle=False, # 是否打乱顺序
num_workers=0, # 读取的线程个数
)
以上为数据读取方式,本博客内容均来源于DataWhale