pytorch——数据之读取数据
深度学习中训练一个模型一般都是采用如下的几个模块:
数据——模型——损失函数——优化器——迭代训练
数据
数据一般包含:数据收集、数据划分、数据读取、数据预处理
其中数据读取一般是采用DataLoader,功能:构建可迭代的数据装载器
常见的几个参数
DataLoader(
dataset,
batch_size=1
shuffle=False,
num_workers=0,
drop_last)
其中drop_last:当样本数不能被batchsize整除时,是否舍弃最后一批数据
常见的几个变量名:
epoch:所有训练样本都已经输送到模型中,称为epoch
iteration:一批样本输入到模型中,称为一个iteration
batchsize:批大小,决定一个epoch有多少个iteration
例如:样本数有80个,每次输送8个样本,分10次输送。其中batchsize:8,iteration:10
数据读取
举个例子:
加载数据一般从这里开始,然后逐一debug。
第一步:
for i, data in enumerate(train_loader)
第二步:进入dataloader.py
def __iter__(self) -> '_BaseDataLoaderIter'
此段代码的主要目的:返回一个迭代器对象,用于在数据加载过程中逐个获取数据
代码包含了一个条件语句来处理不同的情况:
- 如果DataLoader使用了多个工作进程且设置了persistent_workers参数为True,则创建一个持续的迭代器对象。迭代器对象只创建一次,可以重复使用。如果为创建迭代器,则通过_get_iterator()方法创建一个新的迭代器。如果存在了,通过_reset()方法重置迭代器对象的状态。
- 如果DataLoader未使用对个工作进程或者persistent_workers参数为False,则通过_get_iterator()创建一个新的迭代器
第三步:假设采用_SingleProcessDataLoaderIter(),则进入
class _SingleProcessDataLoaderIter(_BaseDataLoaderIter)
def _next_data(self):
index = self._next_index() # may raise StopIteration
data = self._dataset_fetcher.fetch(index) # may raise StopIteration
if self._pin_memory:
data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)
return data
发现先获取index,即:用sampler.py获取index
接着根据index,来获取data,进而对数据进行整理
class _MapDatasetFetcher(_BaseDatasetFetcher):
def fetch(self, possibly_batched_index):
if self.auto_collation:
if hasattr(self.dataset, "__getitems__") and self.dataset.__getitems__:
data = self.dataset.__getitems__(possibly_batched_index)
else:
data = [self.dataset[idx] for idx in possibly_batched_index]
else:
data = self.dataset[possibly_batched_index]
return self.collate_fn(data)
进而进入
def __getitem__(self, index):
path_img, label = self.data_info[index]
自此读取数据
根据上述可知:
数据读取
1.读哪些数据
sampler输出的index
2.从哪读数据
DataSet中的data_dir
3.怎么读数据
DataSet中的getitem
大致流程:DataLoader——DataLoaderIter——Sampler——index——DatasetFetcher——Dataset——getitem——Img、Label——collate_fn——BatchData