PyTorch 1.x 数据IO及预处理
1. 简介
2. 加载数据
- Pytorch的数据读取主要包含三个类:
- torch.utils.data.Dataset
- torch.utils.data.DataLoader
- torch.utils.data.dataloader.DataLoaderIter
- 三者关系:
- torch.utils.data.Dataset被装入 torch.utils.data.DataLoader
- torch.utils.data.DataLoader被装入 torch.utils.data.dataloader.DataLoaderIter
- 伪代码示例
class CustomDataset(Dataset):
# 自定义自己的dataset
dataset = CustomDataset()
dataloader = Dataloader(dataset, ...)
for data in dataloader:
# training...
- 在上面的for循环中,主要操作过程:
- 调用了dataloader 的
__iter__
() 方法, 产生了一个DataLoaderIter
- 反复调用DataLoaderIter 的
__next__()
来得到batch, 具体操作就是, 多次调用Dataset的__getitem__()
方法 (如果num_worker>0
就多线程调用), 然后用collate_fn
来把它们打包成batch. 中间还会涉及到shuffle
, 以及sample
的方法等, 这里就不多说了. - 当数据读完后,
__next__()
抛出一个StopIteration异常, for循环结束, dataloader 失效
- 调用了dataloader 的
2.1 torch.utils.data.Dataset
- 功能
- 主要功能:读取一个样本Data和Label
- 代表数据集的抽象类
- 所有表示从Key到数据样本进行映射的数据集(即字典数据类型)都必须是此类的派生类
- 所有派生类都应重载
__getitem __()
,以获取指定Key的数据样本 - 派生类还可重载
__len__()
,以返回整个数据集的长度
- 使用方法
- 通过Index或Key,获取对应的数据样本和GT标签
- 示例
class my_dataset(torch.utils.data.Dataset):
def __init__(self, trainingImageDir, bndbox, keypointsPixel, keypointsWorld, center):
self.trainingImageDir = trainingImageDir
self.mean = Img_mean
self.std = Img_std
self.bndbox = bndbox
self.keypointsPixel = keypointsPixel
self.keypointsWorld = keypointsWorld
self.center = center
self.depth_thres = 0.4
def __getitem__(self, index):
# data4DTemp = scio.loadmat(self.trainingImageDir + str(index+1) + '.mat')['DepthNormal']
data4DTemp = scio.loadmat(self.trainingImageDir + str(index) + '.mat')['DepthNormal']
depthTemp = data4DTemp[:,:,3]
data, label = dataPreprocess(index, depthTemp, self.keypointsPixel, self.keypointsWorld, self.bndbox, self.center, self.depth_thres)
return data, label
def __len__(self):
return len(self.bndbox)
2.2 torch.utils.data.DataLoader
- 定义
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
batch_sampler=None, num_workers=0, collate_fn=None,
pin_memory=False, drop_last=False, timeout=0,
worker_init_fn=None, *, prefetch_factor=2,
persistent_workers=False)
- 参数说明:
参数 | 说明 |
---|---|
dataset | 从中加载数据的数据集 |
batch_size | 每批次要加载的样本数 |
shuffle | 设置为True以使数据在每个训练epoch都重新洗牌 |
collate_fn | 这个函数用来打包batch |
sampler | 定义从数据集中抽取样本的策略。 如果指定,则shuffle必须为False |
batch_sampler | 类似于采样器(sampler),但一次返回一批索引。 与batch_size,shuffle,sampler和drop_last互斥 |
num_workers | 用于数据加载的子进程数量。 0表示将在主进程中加载数据。 (默认值:0) |
collate_fn | 合并样本列表以形成小批量 |
pin_memory | 如果为True,则数据加载器在将张量返回之前将其复制到CUDA固定的内存中 |
drop_last | 如果数据集大小不能被批量大小整除,则设置为True以删除最后一个不完整的批量。 如果为False并且数据集的大小不能被批次大小整除,则最后一批将较小(默认值:False) |
timeout | 如果为正,则为从工作进程收集批次的超时值。 应始终为非负数 (默认值:0) |
worker_init_fn | 如果不为None,则在种子创建之后和数据加载之前,将在每个工作子进程上以工作ID([0,num_workers-1]中的int)作为输入来调用此方法 (默认值:无) |
- 功能
- 是torch.utils.data.dataloader.DataLoaderIter的一个框架
- 定义了一堆成员变量, 到时候赋给DataLoaderIter
- 然后有一个
__iter__()
函数, 把自己 "装进"DataLoaderIter 里面
def __iter__(self):
return DataLoaderIter(self)
-
工作流程
- 第一步:通过Dataset类里面的
__getitem__
函数获取单个的数据 - 第二步:把多个“单个数据”组合成batch
- 第三步:使用collate_fn所指定的函数对这个batch做一些操作,比如padding啊之类的
- 第一步:通过Dataset类里面的
-
示例
test_image_datasets = my_dataset(testingImageDir, bndbox_test, keypointsPixeltest, keypointsWorldtest, center_test)
test_dataloaders = torch.utils.data.DataLoader(test_image_datasets, batch_size = batch_size,
shuffle = False, num_workers = 8)