经过上面的一个例子,相信已经能够了解pytorch的构建并训练的基本步骤了,但是,这里还有几个问题,如果我们的数据并不像上面展示的那样每个文件夹存储一个类别的文件,而是使用的索引的方式存储的类别和文件名呢?如果我们的数据时文本序列呢?这又怎么进行数据的读取以及数据管道的构建呢?
1. 迭代器
Python的迭代器和生成器是一个非常好的能够构建数据管道的方式,试想一下,如果全部的数据都一股脑的读入内存,那么当数据量极大的时候内存绝对会爆掉的,这个时候迭代器和生成器的优势就极大地体现出来了,因为迭代器和生成器一个函数,每次都利用函数记住取到了哪个位置的数据,下次取得时候直接从上次取出的地方再取数据即可。
构建迭代器时最重要的就是需要重写两个方法,__iter__
和 __next__
。StopIteration 异常用于标识迭代的完成,防止出现无限循环的情况,在 next() 方法中我们可以设置在完成指定循环次数后触发 StopIteration 异常来结束迭代。下面我构建一个迭代器当作示例:
class MyNumbers:
def __init__(self, list):
# 获取传入的列表
self.list = list
# 当前的索引
self.index = 0
# 列表的长度,即索引的界限
self.n = len(list)
def __iter__(self):
# 返回迭代器本身
return self
def __next__(self):
if self.index < self.n:
num = self.list[self.index]
self.index += 1
return num
else:
raise StopIteration
itera_list = MyNumbers([1,4,6,9])
for i in itera_list:
print(i)
以上迭代器输出为:
1
4
6
9
2. 生成器
生成器与迭代器其实差不多,可以理解为生成器是迭代器的子类,是用 yield
关键字来定义的。一个带有 yield 的函数就是一个 generator,它和普通函数不同,生成一个 generator 看起来像函数调用,但不会执行任何函数代码,直到对其调用 next()(在 for 循环中会自动调用 next())才开始执行。当函数执行结束时,generator 自动抛出 StopIteration 异常,表示迭代完成。在 for 循环里,无需处理 StopIteration 异常,循环会正常结束。
虽然执行流程仍按函数的流程执行,但每执行到一个 yield 语句就会中断,并返回一个迭代值,下次执行时从 yield 的下一个语句继续执行。看起来就好像一个函数在正常执行的过程中被 yield 中断了数次,每次中断都会通过 yield 返回当前的迭代值。
下面我来示例一下用生成器生成一个1-10数字的平方的函数:
def gengerate():
for i in range(1, 11):
yield i **2
for i in gengerate():
print(i)
可以看到,生成器本质上就是一个迭代器的再封装。
3. Dataset
一般来说PyTorch中深度学习训练的流程是这样的:
- 创建Dateset,Dataset负责建立索引到样本的映射
- Dataset传递给DataLoader
- DataLoader迭代产生训练数据提供给模型
让我们来来想象一下,如果让你来从数据集中构建一个batch,你会怎么做?
- 首先得知道数据集的长度
n
吧,否则你一个batch一个batch的取都不知道什么时候取完了 - 然后再从
0 - n-1
的数据集中一个不断的取batch个数据的下标,这里取数据的方法可以是随机的,也可以是顺序的 - 从取得的下标中得到下标对应的数据
- 将数据个标签整合为一个元组进行输出
以上的流程其实也是 Dataset
实现的流程,Dataset
支持自定义数据集,但是必须要继承 Dataset
,然后必须重写 __len__()
方法和 __getitem__(index)
方法,前者是获取到数据集的大小,后者是读取数据并预处理数据,返回数据和标签(如果有标签)。
接下来我们看一个自定义Dataset的例子:
from torch.utils.data import Dataset
class face_dataset(Dataset):
def __init__(self):
self.file_path = './data/faces/'
f=open("final_train_tag_dict.txt","r")
self.label_dict=eval(f.read())
f.close()
def __getitem__(self,index):
label = list(self.label_dict.values())[index-1]
img_id = list(self.label_dict.keys())[index-1]
img_path = self.file_path+str(img_id)+".jpg"
img = np.array(Image.open(img_path))
return img,label
def __len__(self):
return len(self.label_dict)
当然,自定义Dataset里面传入的数据是处理好的数据,比如经过了归一化、数据增强等操作后的数据。
4. DataLoader
DataLoader
就是负责以特定的方式从数据集中迭代的产生一个个batch的样本集合,所以该函数是不需要进行自定义的,就负责用这个函数就行,下面对 DataLoader
的一些参数进行解释。
DataLoader(
dataset,
batch_size=1,
shuffle=False,
sampler=None,
batch_sampler=None,
num_workers=0,
collate_fn=None,
pin_memory=False,
drop_last=False,
timeout=0,
worker_init_fn=None,
multiprocessing_context=None,
)
一般情况下,我们仅会配置 dataset
, batch_size
, shuffle
, num_workers
, drop_last
这五个参数,其他参数使用默认值即可。
dataset
: 数据集batch_size
: 批次大小shuffle
: 是否乱序sampler
: 样本采样函数,一般无需设置。batch_sampler
: 批次采样函数,一般无需设置。num_workers
: 使用多进程读取数据,设置的进程数。collate_fn
: 整理一个批次数据的函数。pin_memory
: 是否设置为锁业内存。默认为False,锁业内存不会使用虚拟内存(硬盘),从锁业内存拷贝到GPU上速度会更快。drop_last
: 是否丢弃最后一个样本数量不足batch_size批次数据。timeout
: 加载一个数据批次的最长等待时间,一般无需设置。worker_init_fn
: 每个worker中dataset的初始化函数,常用于 IterableDataset。一般不使用。