本专题主要是解决Pytorch框架下项目的数据预处理工作
Table of Contents:
1. HDF5文件简介
2. Python中的_, __, __xx__区别
3. Dataset类
4. DataLoader类
DataLoader类是 torch.utils.data 库下的一个类,这是不用用户自定义的,直接调用即可,现在来简略谈谈这个类的功能,与一些关键代码。
先看看如何调用这个类的:
dataset_train_h5 = H5Dataset("./data_train", mode='train')
# dataset_h5 is what ?
trainloader = utilsdata.DataLoader(dataset_h5, batch_size=5, shuffle=True)
# trainloader is what ?
dataset_val_h5 = H5Dataset("./data_val", mode='val')
valloader = utilsdata.DataLoader(dataset_val_h5, batch_size=1, shuffle=False)
1. DataLoader类功能
由前面几节内容可知,直接调用 dataset[i] 不就可以返回训练样本了,为什么还要使用 DataLoader类呢?因为 dataset[i] 功能比较单一,或者说功能有限,而 DataLoader类可以将 dataset 装饰成迭代器,并且可以返回一个 batch 的数据。因此就可以用 enumerate 得到一个 batch 的数据data,由(images, labels)组成。
迭代器是访问集合元素的一种方式。迭代器对象从集合的第一个元素开始访问,知道所有的元素被访问完结束。
迭代器有两个基本方法:
__next__方法:返回迭代器的下一个(批次)元素
__iter__方法:返回迭代器对象本身(用来将一个可迭代对象转换为迭代器,“迭代器”指的是 iter 所返回的一个支持 next() 的对象)
2. DataLoader类关键代码
2.1 如何实现返回一个 batch 的数据呢?
前面说过,DataLoader类可以将 dataset 装饰成迭代器,然后再用 next() 或者 for 遍历数据。
过程:
- container = iter(list)
- container.next() # for
上面函数可以看出,next() 或 for 遍历数据时就是调用私有函数 def __next__():的。
其实,这些细节你不想深究也行,只要知道接口处传一个实参 batch_size 就返回一个迭代器,并且这个迭代器在遍历的时候一次会返回一个 batch_size 的样本数据。