数据读取与数据扩增
在进行本章学习之前配置了baseline代码环境,但是环境配置好之后运行时在import torchvision时一直报DLL load failed。删除torchvision重新安装依然报错,最后将torchvision版本指定为0.2.1之后问题解决。目前遇到的问题是解压test图片压缩文件夹的时候,winzip报too many entries错误。猜测是zip里40000个entry超过了winzip的处理能力。重新order了7zip,目前还没有安装完成。
由于基本没有python基础,所以这里第一次见到python里定义了class。于是查了一下相关内容,原来python也是面向对象的语言。于是简单看了一下教程,比较简单,就不过多介绍了。具体参见:python class
数据读取
关于数据读取部分介绍了两个python库,Pillow和OpenCV。和之前的matlab,C++里的数据读取方法类似,简单略过了。
数据扩增
数据扩增是之前没有接触过的,其主要目的是增加训练集样本,可以使模型具有更强的泛化能力。
数据扩增方法有很多:从颜色空间、尺度空间到样本空间,同时根据不同任务数据扩增都有相应的区别。应该根据具体的任务对数据扩增的方法进行选择,不要选一些和当前任务不相关的扩增方法,或者不适合当前任务的方法,比如本次的字符识别任务如果进行图片翻转扩增就会使字符的识别结果发生改变,显然是不合理的。
对于图像分类,数据扩增一般不会改变标签;对于物体检测,数据扩增会改变物体坐标位置;对于图像分割,数据扩增会改变像素标签。所以根据具体的任务,数据扩增之后也要对标签进行相应的处理。
常用的数据扩增库:
-
torchvision
pytorch官方提供的数据扩增库,提供了基本的数据数据扩增方法,可以无缝与torch进行集成;但数据扩增方法种类较少,且速度中等; -
imgaug
imgaug是常用的第三方数据扩增库,提供了多样的数据扩增方法,且组合起来非常方便,速度较快; -
albumentations
是常用的第三方数据扩增库,提供了多样的数据扩增方法,对图像分类、语义分割、物体检测和关键点检测都支持,速度较快。
Pytorch读取数据
Baseline的代码里有一个SVHNDataset类,它通过继承torch.utils.data.dataset 来对数据进行封装。代码与说明如下:
class SVHNDataset(Dataset):
# __init__函数类似C++中的构造函数,当实例化类对象的时候,自动调用该函数。不过用法比C++简单多了。
def __init__(self, img_path, img_label, transform=None):
self.img_path = img_path
self.img_label = img_label
# 这里用self.transform = transform是不就可以了,不过没有进行实际验证。
if transform is not None:
self.transform = transform
else:
self.transform = None
def __getitem__(self, index):
# PIL读取数据数RGB形式,这里为什么还要convert?
img = Image.open(self.img_path[index]).convert('RGB')
if self.transform is not None:
img = self.transform(img)
# 原始SVHN中类别10为数字0
lbl = np.array(self.img_label[index], dtype=np.int)
lbl = list(lbl) + (5 - len(lbl)) * [10]
return img, torch.from_numpy(np.array(lbl[:5]))
def __len__(self):
return len(self.img_path)