【PyTorch深度学习实践】学习笔记 数据集的加载Dataset和DataLoader原理

简而言之,这俩就是自动帮我们取数据,避免了接触底层代码

1、前言

机器学习模型训练五大步骤;第一是数据,第二是模型,第三是损失函数,第四是优化器,第五个是迭代训练过程。
这里主要学习数据模块当中的数据读取,数据模块通常还会分为四个子模块:数据收集、数据划分、数据读取、数据预处理。
在这里插入图片描述
在进行实验之前,需要收集数据,数据包括原始样本和标签;
有了原始数据之后,需要对数据集进行划分,把数据集划分为训练集、验证集和测试集;训练集用于训练模型,验证集用于验证模型是否过拟合,也可以理解为用验证集挑选模型的超参数,测试集用于测试模型的性能,测试模型的泛化能力;
第三个子模块是数据读取,也就是即将要学习的DataLoader,pytorch中数据读取的核心就是DataLoader
第四个子模块是数据预处理,把数据读取进来往往还需要对数据进行一系列的图像预处理,比如数据的中心化,标准化,旋转或者翻转等等。pytorch中数据预处理是通过transforms进行处理的

详情请见原文链接

经过debug实践和总结后,如下。

  • 子模块DataLoader还会细分为两个子模块,Sampler和DataSet;Sample的功能是生成索引,也就是样本的序号;Dataset是根据索引去读取图片以及对应的标签;

2、DataLoader and Dataset详解

2.1 DataLoader

(1)torch.utils.data.DataLoader
功能:构建可迭代的数据装载器;

dataset: Dataset类,决定数据从哪里读取及如何读取; batchsize:批大小;
num_works:
是否多进程读取数据;
shuffle: 每个epoch是否乱序;
drop_last:当样本数不能被batchsize整除时,是否舍弃最后一批数据;

DataLoader( dataset = dataset , #dataset 是继承了dataset类之后加载数据集提供路径
            batch_size = 32, #选择batch_size的大小
            shuffle = true,  #增强数据集随机性
            num_workers = 2 ) #多进程读数据

再次强调

  1. Epoch:所有训练样本都已输入到模型中,称为一个Epoch;
  2. Iteration: 一批样本输入到模型中,称之为一个Iteration;
  3. Batchsize: 批大小,决定一个Epoch中有多少个Iteration;

(2)torch.utils.data.Dataset
Dataset是用来定义数据从哪里读取,以及如何读取的问题;
功能:Dataset抽象类,所有自定义的Dataset需要继承它,并且复写__getitem__();
getitem:接收一个索引,返回一个样本

解决三个问题

下面对人民币二分类的数据进行读取,从三个方面了解pytorch的读取机制,分别为读哪些数据、从哪读数据、怎么读数据;
在这里插入图片描述

1.读哪些数据

具体来说,在每一个Iteration的时候应该读取哪些数据,每一个Iteration读取一个Batch大小的数据,假如有80个样本,那么从80个样本中读取8个样本,那么应该读取哪八个样本,这就是我们的第一个问题,读哪些数据;

2.从哪读数据

意思是在硬盘当中,我们应该怎么找到对应的数据,在哪里设置参数;

3.怎么读数据

从代码中学习可以发现,数据的获取是从DataLoader迭代器中不停地去获取一个Batchsize大小的数据,通过for循环获取的;
下面开始debug调试看读取数据的过程。
首先在pycharm中对

for i, data in enumerate(train_loader):

这一行代码设置断点,然后执行Debug,然后点击步进功能键,就可以跳转到对应的函数中,可以发现是跳到了dataloader.py中的__iter__()函数;具体如下所示:

def __iter__(self):
	if self.num_workers == 0:
		return _SingleProcessDataLoaderIter(self)
	else:
		return _MultiProcessingDataLoaderIter(self) #进程问题
  • 这段代码是一个if的判断语句,其功能是判断是否采用多进程;如果采用多进程,有多进程的读取机制;如果是单进程,有单进程的读取机制;这里以单进程进行演示;
  • 单进程当中,最主要的是__next__()函数
    (这里涉及到python里的迭代器的理解)在next中会获取index和data,回想一下数据读取中的三个问题,第一个问题是读哪些数据;__next__函数就告诉我们,在每一个Iteration当中读取哪些数据
class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
	def __init__(self,loader):
		super(_SingleProcessDataLoaderIter,self).__init__(loader)
		assert self.timeout == 0
		assert self.num_workers == 0
		self.dataset_fetcher = _DatasetKind.create_fetcher(self.dataset_kind, self.dataset,self.auto_collation, self.collate_fn, self.drop_last)

    def __next__(self):
        index = self._next_index()  # may raise StopIteration
        data = self.dataset_fetcher.fetch(index)  # may raise StopIteration
        if self.pin_memory:
            data = _utils.pin_memory.pin_memory(data)
        return data

    next = __next__  # Python 2 compatibility
  • 现在将光标对准def next(self)中的 index=self._next_index(),点击功能区中的run to
    cursor,然后程序就会运行到这一行,点击功能区中的step
    into,进入到_next_index()函数中了解是怎么获得数据的index的;之后代码会跳到下面的代码中
   def _next_index(self):
        return next(self.sampler_iter)  # may raise StopIteration
  • 再点击一下step into就进入了sampler.py文件中,sampler是一个采样器其功能是告诉我们每一个batch_size应该读取哪些数据(回答了问题1!),如挑选出一个Iteration中的index,因为bitch_size的值是16,其在pycharm中的表示形式为:

Index={list}<class ‘list’>: [4, 135, 113, 34, 47, 140, 87, 0, 59, 33, 144, 43, 83, 133, 1, 78] #That’s explained everything!
self={_SingleProcessDataLoaderlter}<torch.utils.data.dataloader._SingleProcessDataLoaderIter object at 0x000001F11BF6A7C8>

  • 可以看到!index是个列表,里面装有shuffle后的batch_size大小的样本索引序列!有了Index之后,代码中会进入一个dataset_fetcher.fetch()函数
  • 点击功能区中的step_into,进入到一个_MapDatasetFetcher()类当中,在这个类里面实现了具体的数据读取,具体代码如下。
class _MapDatasetFetcher(_BaseDatasetFetcher):
    def __init__(self, dataset, auto_collation, collate_fn, drop_last):
        super(_MapDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last)

    def fetch(self, possibly_batched_index):
        if self.auto_collation:
            data = [self.dataset[idx] for idx in possibly_batched_index]  #调用了dataset,通过一系列的data拼接成一个list;
        else:
            data = self.dataset[possibly_batched_index]
        return self.collate_fn(data)
  • 代码中调用了dataset,通过输入一个索引index返回一个data,通过一系列的data拼接成一个list;
  • 采用步进查看一下这个过程,代码跳转到mt_dataset.py中的类RMBdataset()中的__getitem__()函数中,所以dataset最重要最核心的就是__getitem__()函数;了解这些底层,有助于以后快速上手大的project,看懂数据集载入,数据处理的操作。
def __getitem__(self, index):
        path_img, label = self.data_info[index]
        img = Image.open(path_img).convert('RGB')     # 0~255

        if self.transform is not None:
            img = self.transform(img)   # 在这里做transform,转为tensor等等

        return img, label
  • 这里已经实现了data_info()函数,即对数据进行了初步的读取,已经得到了图片的路径和标签的列表了回答了问题2!),再把index相应的值读出来即可(
    关于如何加载自定义数据集见这个博客,这里就解释了博客里讲的为什么要分两步,第一步制作标签索引的txt文件,第二步写Dataset类的getitem函数);然后通过Image.open实现了一个数据的读取(回答了问题3!)

  • 之后点击step_out跳出该函数,会返回fetch()函数中;

    def fetch(self, possibly_batched_index):
        if self.auto_collation:
            data = [self.dataset[idx] for idx in possibly_batched_index]
        else:
            data = self.dataset[possibly_batched_index]
        return self.collate_fn(data)

  • 在fetch() return 的时候会进入一个collate_fn(),它是数据的整理器,会将我们读取到的16个数据整理出一个batch的形式;

通过以上的分析,可以回答一开始提出的数据读取的三个问题:1、读哪些数据;2、从哪读数据;3、怎么读数据;
在这里插入图片描述
(1)从代码中可以发现,index是从sampler.py中输出的一个列表,所以读哪些数据是由sampler得到的;
(2)从代码中看,是从Dataset中的参数data_dir,告诉我们pytorch是从硬盘中的哪一个文件夹获取数据;

train_data = MyDataset(txt='../gender/train1.txt',type = "train", transform=transform_train) #Dataset类是自己写的,传进去的data_dir即"txt='../gender/train1.txt' "参数
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, sampler=train_sampler)

(3)从代码中可以发现,pytorch是从Dataset的getitem()中具体实现的,根据索引去读取数据;

Dataloader读取数据很复杂,需要经过四五个函数的跳转才能最终读取数据
为了简单,通过流程图来对数据读取机制有一个简单的认识;

在这里插入图片描述

总结精华:

  • 简单描述一下流程图,首先在for循环中去使用DataLoader,进入DataLoader之后是否采用多进程进入DataLoaderlter,进入DataLoaderIter之后会使用sampler去获取Index,拿到索引之后传输到DatasetFetcher,在DatasetFetcher中会调用Dataset,Dataset根据给定的Index以及在getitem中加载了索引文件txt中全部的数据集的图片路径和标签,读取一个batch_size大小的Img和Label数据之后,通过一个collate_fn将数据进行整理,整理成batch_Data的形式,接着就可以输入到模型中训练;( 一次dataloader只读取一个batch大小的数据!非全部数据集。在一个epoch中通过for循环进行iteration次dataloader的
  • 可以通过简单的例子来验证是否是这样,,看看他的代码实现是怎样的。

学习完这里真的太不容易了,但搞清楚一个重要的事情的来龙去脉还是很有成就感的!继续加油!

by 小李

如果你坚持到这里了,请一定不要停,山顶的景色更迷人!好戏还在后面呢。加油!
欢迎交流学习和批评指正,你的点赞和留言将会是我更新的动力!谢谢😃

  • 20
    点赞
  • 64
    收藏
    觉得还不错? 一键收藏
  • 5
    评论
在Python中,dataset是一个用于获取数据和标签的类。它提供了两个主要功能:一是获取每个数据和其对应的标签,二是确定数据集的总大小。\[2\]在给定的代码中,MyData是一个继承自torch.utils.data.Dataset的自定义数据集类。它通过重写__init__、__getitem__和__len__方法来实现这两个功能。__init__方法初始化了数据集的根目录和标签目录,并获取了所有图像的路径。__getitem__方法根据给定的索引返回对应的图像和标签。__len__方法返回数据集的总大小。\[1\] 此外,还有一个与dataset相关的类叫做dataloaderdataloader用于将dataset中的数据按照指定的batch size进行分批加载。它可以将dataset中的数据流动起来,实现批量输出。\[3\]在给定的代码中,train_dataset是由ants_dataset和bees_dataset拼接而成的数据集。可以使用len(train_dataset)命令在Python控制台中查看train_dataset数据集中的元素数量。train_dataset\[230\]可以获取train_dataset中索引为230的元素,其中包含图像和标签。img.show()可以显示该图像。\[1\] #### 引用[.reference_title] - *1* *2* [PyTorch中如何读取数据(Dataset类的使用)](https://blog.csdn.net/m0_51816252/article/details/124960748)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^control_2,239^v3^insert_chatgpt"}} ] [.reference_item] - *3* [Pytorch Dataset类的使用(个人学习笔记)](https://blog.csdn.net/weixin_46355597/article/details/129316051)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^control_2,239^v3^insert_chatgpt"}} ] [.reference_item] [ .reference_list ]
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值