还是拿来自学用 多谢多谢 对于torch本人无人讨论 拿来主义的 多谢理解
0,Dataset和DataLoader功能简介
Pytorch通常使用Dataset和DataLoader这两个工具类来构建数据管道。
Dataset定义了数据集的内容,它相当于一个类似列表的数据结构,具有确定的长度,能够用索引获取数据集中的元素。
而DataLoader定义了按batch加载数据集的方法,它是一个实现了__iter__
方法的可迭代对象,每次迭代输出一个batch的数据。
DataLoader能够控制batch的大小,batch中元素的采样方法,以及将batch结果整理成模型所需输入形式的方法,并且能够使用多进程读取数据。
在绝大部分情况下,用户只需实现Dataset的__len__
方法和__getitem__
方法,就可以轻松构建自己的数据集,并用默认数据管道进行加载。
对于一些复杂的数据集,用户可能还要自己设计 DataLoader中的 collate_fn方法以便将获取的一个批次的数据整理成模型需要的输入形式。
一,深入理解Dataset和DataLoader原理
1,获取一个batch数据的步骤
让我们考虑一下从一个数据集中获取一个batch的数据需要哪些步骤。
(假定数据集的特征和标签分别表示为张量X
和Y
,数据集可以表示为(X,Y)
, 假定batch大小为m
)
1,首先我们要确定数据集的长度n
。结果类似:n = 1000
。
2,然后我们从0
到n-1
的范围中抽样出m
个数(batch大小)。假定m=4
, 拿到的结果是一个列表,类似:indices = [1,4,8,9]
3,接着我们从数据集中去取这m
个数对应下标的元素。拿到的结果是一个元组列表,类似:samples = [(X[1],Y[1]),(X[4],Y[4]),(X[8],Y[8]),(X[9],Y[9])]
4,最后我们将结果整理成两个张量作为输出。
拿到的结果是两个张量,类似batch = (features,labels)
,
其中 features = torch.stack([X[1],X[4],X[8],X[9]])
labels = torch.stack([Y[1],Y[4],Y[8],Y[9]])
2,Dataset和DataLoader的功能分工
上述第1个步骤确定数据集的长度是由 Dataset的__len__
方法实现的。
第2个步骤从0
到n-1
的范围中抽样出m
个数的方法是由 DataLoader的 sampler
和 batch_sampler
参数指定的。
sampler
参数指定单个元素抽样方法,一般无需用户设置,程序默认在DataLoader的参数shuffle=True
时采用随机抽样,shuffle=False
时采用顺序抽样。
batch_sampler
参数将多个抽样的元素整理成一个列表,一般无需用户设置,默认方法在DataLoader的参数drop_last=True
时会丢弃数据集最后一个长度不能被batch大小整除的批次,在drop_last=False
时保留最后一个批次。
第3个步骤的核心逻辑根据下标取数据集中的元素 是由 Dataset的 __getitem__
方法实现的。
第4个步骤的逻辑由DataLoader的参数collate_fn
指定。一般情况下也无需用户设置。
Dataset和DataLoader的一般使用方式如下:
将DataLoader内部调用方式步骤拆解如下:
3,Dataset和DataLoader的核心源码
以下是 Dataset和 DataLoader的核心源码,省略了为了提升性能而引入的诸如多进程读取数据相关的代码。
我们来测试一番
完美, 和预期一致!
二,使用Dataset创建数据集
Dataset创建数据集常用的方法有:
- 使用 torch.utils.data.TensorDataset 根据Tensor创建数据集(numpy的array,Pandas的DataFrame需要先转换成Tensor)。
- 使用 torchvision.datasets.ImageFolder 根据图片目录创建图片数据集。
- 继承 torch.utils.data.Dataset 创建自定义数据集。
此外,还可以通过
- torch.utils.data.random_split 将一个数据集分割成多份,常用于分割训练集,验证集和测试集。
- 调用Dataset的加法运算符(
+
)将多个数据集合并成一个数据集。
1,根据Tensor创建数据集
2,根据图片目录创建图片数据集
3,创建自定义数据集
下面我们通过另外一种方式,即继承 torch.utils.data.Dataset 创建自定义数据集的方式来对 cifar2构建 数据管道。
三,使用DataLoader加载数据集
DataLoader能够控制batch的大小,batch中元素的采样方法,以及将batch结果整理成模型所需输入形式的方法,并且能够使用多进程读取数据。
DataLoader的函数签名如下。
一般情况下,我们仅仅会配置 dataset, batch_size, shuffle, num_workers,pin_memory, drop_last这六个参数,
有时候对于一些复杂结构的数据集,还需要自定义collate_fn函数,其他参数一般使用默认值即可。
DataLoader除了可以加载我们前面讲的 torch.utils.data.Dataset 外,还能够加载另外一种数据集 torch.utils.data.IterableDataset。
和Dataset数据集相当于一种列表结构不同,IterableDataset相当于一种迭代器结构。它更加复杂,一般较少使用。
- dataset : 数据集
- batch_size: 批次大小
- shuffle: 是否乱序
- sampler: 样本采样函数,一般无需设置。
- batch_sampler: 批次采样函数,一般无需设置。
- num_workers: 使用多进程读取数据,设置的进程数。
- collate_fn: 整理一个批次数据的函数。
- pin_memory: 是否设置为锁业内存。默认为False,锁业内存不会使用虚拟内存(硬盘),从锁业内存拷贝到GPU上速度会更快。
- drop_last: 是否丢弃最后一个样本数量不足batch_size批次数据。
- timeout: 加载一个数据批次的最长等待时间,一般无需设置。
- worker_init_fn: 每个worker中dataset的初始化函数,常用于 IterableDataset。一般不使用。