pytorch学习

本文详细介绍了在PyTorch中训练深度学习模型时数据读取的过程,涉及DataLoader的参数如batch_size和shuffle,以及数据读取的步骤,包括通过Sampler获取index,从Dataset中使用__getitem__方法读取数据,最后经过collate_fn形成批次数据。
摘要由CSDN通过智能技术生成

pytorch——数据之读取数据

深度学习中训练一个模型一般都是采用如下的几个模块:

数据——模型——损失函数——优化器——迭代训练

数据

数据一般包含:数据收集、数据划分、数据读取、数据预处理

其中数据读取一般是采用DataLoader功能:构建可迭代的数据装载器
常见的几个参数

DataLoader(
					dataset,
					batch_size=1
					shuffle=False,
					num_workers=0,
					drop_last)

其中drop_last:当样本数不能被batchsize整除时,是否舍弃最后一批数据
常见的几个变量名:
epoch:所有训练样本都已经输送到模型中,称为epoch
iteration:一批样本输入到模型中,称为一个iteration
batchsize:批大小,决定一个epoch有多少个iteration
例如:样本数有80个,每次输送8个样本,分10次输送。其中batchsize:8,iteration:10

数据读取

举个例子:
加载数据一般从这里开始,然后逐一debug。
第一步:

for i, data in enumerate(train_loader)

第二步:进入dataloader.py

def __iter__(self) -> '_BaseDataLoaderIter'

此段代码的主要目的:返回一个迭代器对象,用于在数据加载过程中逐个获取数据
代码包含了一个条件语句来处理不同的情况:

  • 如果DataLoader使用了多个工作进程且设置了persistent_workers参数为True,则创建一个持续的迭代器对象。迭代器对象只创建一次,可以重复使用。如果为创建迭代器,则通过_get_iterator()方法创建一个新的迭代器。如果存在了,通过_reset()方法重置迭代器对象的状态。
  • 如果DataLoader未使用对个工作进程或者persistent_workers参数为False,则通过_get_iterator()创建一个新的迭代器

第三步:假设采用_SingleProcessDataLoaderIter(),则进入
class _SingleProcessDataLoaderIter(_BaseDataLoaderIter)

    def _next_data(self):
        index = self._next_index()  # may raise StopIteration
        data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
        if self._pin_memory:
            data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)
        return data

发现先获取index,即:用sampler.py获取index
接着根据index,来获取data,进而对数据进行整理

class _MapDatasetFetcher(_BaseDatasetFetcher):
    def fetch(self, possibly_batched_index):
        if self.auto_collation:
            if hasattr(self.dataset, "__getitems__") and self.dataset.__getitems__:
                data = self.dataset.__getitems__(possibly_batched_index)
            else:
                data = [self.dataset[idx] for idx in possibly_batched_index]
        else:
            data = self.dataset[possibly_batched_index]
        return self.collate_fn(data)

进而进入

    def __getitem__(self, index):
        path_img, label = self.data_info[index]

自此读取数据

根据上述可知:

数据读取

1.读哪些数据

sampler输出的index

2.从哪读数据

DataSet中的data_dir

3.怎么读数据

DataSet中的getitem

大致流程:DataLoader——DataLoaderIter——Sampler——index——DatasetFetcher——Dataset——getitem——Img、Label——collate_fn——BatchData

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值