一、前言
在深度学习中,需要加载数据对神经网络进行训练,现有的主流数据集及常用的经典数据集例如COCO,MINIST,CIFAR等,在许多开源的项目中例如MMCV,torchvision中都有对应的加载,对于自己的数据集而言,应该如何加载自定义数据集呢。
torchvision.datasets - PyTorch master documentationpytorch.org本文就图片数据集的加载进行分析,数据集为转化为图片和json标注文件的CIFAR10数据集,数据集的文件格式如下所示torchvision.datasets - PyTorch master documentation本文就图片数据集的加载进行分析,数据集为转化为图片和json标注文件的CIFAR10数据集,数据集的文件格式如下所示
将CIFAR10数据集转换成图片文件和json文件的标注参照这篇文章:
HUST小菜鸡:CIFAR10数据集转换成图片及标注文件zhuanlan.zhihu.com二、直接读取
#读取文件位置
def get_path('path-str'):
...
return file_path
#读取图片
def loader_img(file_path):
#根据图片的位置读取图片并返回读取的图片和标签
#对图片进行处理
...
return imgs_list, label_list
#获取batchsize大小的数据
def get_train_data(imgs_list,label_list,batchsize):
...
return img[1],img[2],...,img[batchsize]
以上是常规的思路,在原理上来说是可行的,但是如果batchsize很大,那么用这种方式去读取数据集会带来如下弊端:
- 将所有的图像数据直接加载到numpy数据中会占用大量的内存
- 由于需要对数据进行导入,每次训练的时候在数据读取阶段会占用大量的时间
- 只使用了单线程去读取,读取效率比较低下
- 拓展性很差,只能对数据进行一些单一的预处理
PyTorch中有工具函数torch.utils.Data.DataLoader,通过这个函数我们在准备加载数据集使用mini-batch的时候可以使用多线程并行处理,这样可以加快我们准备数据集的速度。Datasets就是构建这个工具函数的实例参数之一。这样我们就可以批量加载数据或者并行加载数据
三、class Datasets
class Dataset(object):
r"""An abstract class representing a :class:`Dataset`.
All datasets that represent a map from keys to data samples should subclass
it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a
data sample for a given key. Subclasses could also optionally overwrite