torch.utils.data.DataLoader
Data(dataset,
batch_size=1,
shuffle=False,
sampler=None,
batch_sampler=None,
num_workers=0,
collate_fn=None,
pin_memory=False,
drop_list=False,
timeout=0,
worker_init_fn=None,
multiprocessing_context=None)
功能: 构建可迭代的数据装载器
- dataset:Dataset类,决定数据从哪读取以及如何读取
- batchsize:批大小
- num_works: 是否多进程读取数据
- shuffle:每个epoch是否乱序
- drop_list:当样本数不能被batchsize整除时,是否舍弃最后一批数据
torch.utils.data.Dataset
class Dataset(object):
def __getitem__(self, index):
raise NotImplementedError
def __add__(self, other):
return ConcatDataset([self, other])
功能:
Dataset抽象类,所有自定义的Dataset都要继承它,并且复写__getitem__()函数。
getitem: 接收一个索引,返回一个样本
Pytorch数据读取机制
- 读那些数据?
- 从哪读数据?
- 怎么读数据?
代码调试
Debug常用按钮
1、设置断点,进行Debug调试
2、采用C按钮,跳转到dataloader.py中DataLoader类的__iter__(self)函数中,该处代码表示是否使用多进程。
3、以单进程为例,点击B按钮然后点击C按钮,进入单进程的类当中,在该类中,最重要的函数为__next__(self),该函数会获取index和data。该函数告诉我们读哪些数据。
4、将光标放在345行,即index=self._next_index()上,点击F按钮,然后点击C按钮进入self._next_index()函数中,查看该函数是如何获取index的。
5、再点击以下C按钮,我们进入到sampler.py中的BatchSampler类中。Sampler就是一个采样器,他就是用来告诉我们每个Batchsize该读取那些数据
5、点击两次E按钮,跳出函数,然后点击B按钮,运行345行的代码,运行完成后我们的index就挑选出来了(Batchsize=16)。
6、有了index,接下来就是数据获取,我们点击B按钮进入self.dataset_fetcher.fetch()函数中
7、我们进入到fetch.py文件中的_MapDatasetFetcher类中,在第44行中正式调用了dataset,对dataset输入一个索引index,就会返回一个data,将一些列的data拼接为一个list。
8、先点击B按钮,运行到44行,然后点击C按钮(需要点击两次)进到self.dataset中。我们可以看到该函数跳转到了我们自己创建的my_dataset.py文件中的RMB数据集类中的__getitem__(self, index)函数中,self.data_inifo的每一项为图片的路径和标签,然后我们通过Image.open来读取图片,这就实现了一个数据的读取和标签的获取。
9、点击E按钮跳出该函数,进入到第7步的界面中,我们将光标放在47行并点击F按钮运行到该行,我们可以发现,在我们将读取的数据返回到第六步之前,我们会用self.collate_fn()函数来整理数据,该函数为数据的整理器,它会将我们读取的16个数据整理为一个Batch的形式,可以看到在运行self.collate_fn()函数之前,我们的data为list类型的数据。
10、点击两次B按钮,我们可以发现我们的data变成batch的形式,第一个元素里面为图片Tensor,第二个为标签。
11、点击F按钮返回数据并点击B按钮,此时我们可以看到我们的data为list的形式,第一个元素为图像,第二个元素为标签。有了图像和标签我们就可以对模型进行训练。这就是pytorch的数据读取机制。
现在我们回答上面的三个问题:
1、读那些数据?
Sampler输出的Index
2、从哪读数据?
Dataset中的data_dir
3、怎么读数据
Dataset中的getitem,根据索引读数据