Series Index
PyTorch Week 2: DataLoader and Dataset
Preface
This post records what I learned in Week 2 of the 深度之眼 PyTorch fundamentals course.
I. Data Loading
1 An RMB (renminbi) banknote binary classification task
DataLoader
torch.utils.data.DataLoader(
    dataset,                       # Dataset object: decides where the data comes from and how it is read
    batch_size=1,                  # number of samples per batch
    shuffle=False,                 # whether to reshuffle the data at every epoch
    sampler=None,                  # strategy for drawing sample indices
    batch_sampler=None,            # like sampler, but yields a whole batch of indices at a time
    num_workers=0,                 # number of subprocesses for data loading (0 = load in the main process)
    collate_fn=None,               # merges a list of samples into a mini-batch
    pin_memory=False,              # copy tensors into pinned (page-locked) memory before returning them
    drop_last=False,               # whether to drop the last incomplete batch
    timeout=0,                     # timeout for collecting a batch from workers
    worker_init_fn=None,           # called in each worker subprocess after seeding
    multiprocessing_context=None)
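To make these parameters concrete, here is a minimal usage sketch (the toy dataset and the numbers are invented for illustration) showing how batch_size, shuffle, num_workers and drop_last affect what the DataLoader yields:

import torch
from torch.utils.data import DataLoader, TensorDataset

# a toy map-style dataset: 10 samples, each a 3-dim feature vector with an integer label
toy_set = TensorDataset(torch.randn(10, 3), torch.arange(10))

loader = DataLoader(toy_set,
                    batch_size=4,      # 4 samples per batch
                    shuffle=True,      # reshuffle the indices at every epoch
                    num_workers=0,     # load in the main process
                    drop_last=False)   # keep the final, smaller batch (size 2 here)

for x, y in loader:
    print(x.shape, y.shape)  # torch.Size([4, 3]) torch.Size([4]) ... and a last batch of [2, 3], [2]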
Dataset
class Dataset(object):  # subclass Dataset for your own task and override __getitem__
    def __getitem__(self, index):  # receives an index and returns one sample
        raise NotImplementedError

    def __add__(self, other):
        return ConcatDataset([self, other])
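A small sketch (with a toy dataset invented here) of the two methods above: __getitem__ must be overridden in a subclass, and __add__ means two Dataset objects can be concatenated with +:

from torch.utils.data import Dataset, ConcatDataset

class SquaresDataset(Dataset):
    """Toy map-style dataset: sample i is the pair (i, i*i)."""
    def __init__(self, n):
        self.n = n
    def __getitem__(self, index):
        return index, index * index
    def __len__(self):
        return self.n

combined = SquaresDataset(3) + SquaresDataset(2)  # calls Dataset.__add__
print(isinstance(combined, ConcatDataset))        # True
print(len(combined))                              # 5
print(combined[4])                                # (1, 1): index 4 falls into the second dataset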
Data loading comes down to three questions:
- Which data to read?
- Where to read it from?
- How to read it?
Debugging the code to understand the DataLoader's data-reading mechanism
- Set a breakpoint and step into train_loader
# ============================ step 5/5 training ============================
train_curve = list()
valid_curve = list()

for epoch in range(MAX_EPOCH):
    loss_mean = 0.
    correct = 0.
    total = 0.

    net.train()
    for i, data in enumerate(train_loader):  # set a breakpoint here and step into train_loader
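What the debugger shows can also be reproduced interactively. A short sketch, assuming train_loader has been built as in the course script (batch_size=16, 32x32 RGB images):

# one manual iteration of the loop above: iter() builds the DataLoader iterator,
# next() triggers the whole index -> fetch -> collate chain examined below
data = next(iter(train_loader))
inputs, labels = data
print(type(data), inputs.shape, labels.shape)
# <class 'list'> torch.Size([16, 3, 32, 32]) torch.Size([16])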
2 Step in further: in the single-process case, _get_iterator() is used to obtain the data iterator
def __iter__(self) -> '_BaseDataLoaderIter':
    # When using a single worker the returned iterator should be
    # created everytime to avoid reseting its state
    # However, in the case of a multiple workers iterator
    # the iterator is only created once in the lifetime of the
    # DataLoader object so that workers can be reused
    if self.persistent_workers and self.num_workers > 0:
        if self._iterator is None:
            self._iterator = self._get_iterator()
        else:
            self._iterator._reset(self)
        return self._iterator
    else:
        return self._get_iterator()  # single process: obtain the iterator via _get_iterator()

def _get_iterator(self) -> '_BaseDataLoaderIter':
    if self.num_workers == 0:
        return _SingleProcessDataLoaderIter(self)  # single-process data loading
    else:
        self.check_worker_number_rationality()
        return _MultiProcessingDataLoaderIter(self)
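Which iterator class iter() hands back can be checked without a debugger; a quick sketch with a toy dataset (names invented for illustration):

import torch
from torch.utils.data import DataLoader, TensorDataset

if __name__ == "__main__":  # guard needed because num_workers > 0 spawns subprocesses
    toy_set = TensorDataset(torch.randn(8, 3), torch.arange(8))

    single = DataLoader(toy_set, batch_size=4, num_workers=0)
    multi = DataLoader(toy_set, batch_size=4, num_workers=2)

    print(type(iter(single)).__name__)  # _SingleProcessDataLoaderIter
    print(type(iter(multi)).__name__)   # _MultiProcessingDataLoaderIter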
3 The single-process data-fetching class: _next_data() first calls _next_index() to get a list of indices, then _dataset_fetcher.fetch(index) retrieves the data corresponding to those indices.
class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
    def __init__(self, loader):
        super(_SingleProcessDataLoaderIter, self).__init__(loader)
        assert self._timeout == 0
        assert self._num_workers == 0
        self._dataset_fetcher = _DatasetKind.create_fetcher(
            self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last)

    def _next_data(self):
        index = self._next_index()  # may raise StopIteration
        data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
        if self._pin_memory:
            data = _utils.pin_memory.pin_memory(data)
        return data
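The two calls inside _next_data can also be driven by hand, mirroring the debugger session; this sketch relies on private attributes of the iterator (fine for poking around, not for real code) and assumes the course script's train_loader with num_workers=0 and batch_size=16:

loader_iter = iter(train_loader)                  # a _SingleProcessDataLoaderIter
index = loader_iter._next_index()                 # one batch of indices chosen by the sampler, e.g. a list of 16 ints
data = loader_iter._dataset_fetcher.fetch(index)  # calls dataset[idx] for each index, then collate_fn
print(index)
print(data[0].shape, data[1].shape)               # torch.Size([16, 3, 32, 32]) torch.Size([16])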
4 Step into self._next_index(): this lands in the _BaseDataLoaderIter class, which obtains the indices through _index_sampler
class _BaseDataLoaderIter(object):
    def __init__(self, loader: DataLoader) -> None:
        ...  # initialization omitted in this excerpt

    def __iter__(self) -> '_BaseDataLoaderIter':
        return self

    def _reset(self, loader, first_iter=False):
        self._sampler_iter = iter(self._index_sampler)
        self._num_yielded = 0
        self._IterableDataset_len_called = loader._IterableDataset_len_called

    def _next_index(self):
        return next(self._sampler_iter)
Step into ._sampler_iter, which lives in sampler.py; this is where the indices are generated:
def __iter__(self) -> Iterator[List[int]]:
    batch = []
    for idx in self.sampler:
        batch.append(idx)
        if len(batch) == self.batch_size:
            yield batch  # the list of indices that gets returned
            batch = []
    if len(batch) > 0 and not self.drop_last:
        yield batch
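This __iter__ belongs to BatchSampler, which wraps an element-wise sampler and groups its indices into lists. A minimal sketch showing the effect of batch_size and drop_last:

from torch.utils.data import BatchSampler, RandomSampler, SequentialSampler

indices = range(10)

print(list(BatchSampler(SequentialSampler(indices), batch_size=4, drop_last=False)))
# [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]

print(list(BatchSampler(SequentialSampler(indices), batch_size=4, drop_last=True)))
# [[0, 1, 2, 3], [4, 5, 6, 7]]

print(list(BatchSampler(RandomSampler(indices), batch_size=4, drop_last=True)))
# e.g. [[7, 2, 9, 0], [5, 3, 8, 1]]: shuffled order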
5 In self._dataset_fetcher.fetch: the fetch function retrieves the data corresponding to the indices and returns it.
class _MapDatasetFetcher(_BaseDatasetFetcher):
    def __init__(self, dataset, auto_collation, collate_fn, drop_last):
        super(_MapDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last)

    def fetch(self, possibly_batched_index):
        if self.auto_collation:
            data = [self.dataset[idx] for idx in possibly_batched_index]  # the dataset is indexed here to get each sample
        else:
            data = self.dataset[possibly_batched_index]
        return self.collate_fn(data)
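fetch ends by calling self.collate_fn, which by default stacks the individual samples into batch tensors. A small sketch of that last step, using torch.utils.data.default_collate (exposed under this name in recent PyTorch versions; older versions keep it in torch.utils.data._utils.collate):

import torch
from torch.utils.data import default_collate

# three samples exactly as __getitem__ would return them: (image_tensor, label)
samples = [(torch.randn(3, 32, 32), 0),
           (torch.randn(3, 32, 32), 1),
           (torch.randn(3, 32, 32), 0)]

imgs, labels = default_collate(samples)
print(imgs.shape)  # torch.Size([3, 3, 32, 32]): samples stacked along a new batch dimension
print(labels)      # tensor([0, 1, 0])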
6 Step into the dataset: in the custom RMBDataset, __getitem__ reads a sample according to the index that was passed in.
class RMBDataset(Dataset):
    def __init__(self, data_dir, transform=None):
        """
        Dataset for the RMB denomination classification task
        :param data_dir: str, path to the dataset
        :param transform: torch transform, data preprocessing
        """
        self.label_name = {"1": 0, "100": 1}
        self.data_info = self.get_img_info(data_dir)  # data_info stores all image paths and labels; the DataLoader reads samples from it by index
        self.transform = transform

    def __getitem__(self, index):  # look up the image path and label for the index that was passed in
        path_img, label = self.data_info[index]
        img = Image.open(path_img).convert('RGB')  # read the image from its path, pixel values 0~255
        if self.transform is not None:
            img = self.transform(img)  # apply the transforms here, e.g. conversion to a tensor
        return img, label

    def __len__(self):
        return len(self.data_info)
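The excerpt does not show get_img_info. A plausible sketch of what it does, assuming the usual layout where data_dir contains one sub-folder per class ("1" and "100") full of .jpg files; in RMBDataset this would be a @staticmethod:

import os

rmb_label = {"1": 0, "100": 1}  # folder name -> class index, matching label_name above

def get_img_info(data_dir):
    data_info = []
    for root, dirs, _ in os.walk(data_dir):
        for sub_dir in dirs:  # e.g. data_dir/1 and data_dir/100
            img_names = os.listdir(os.path.join(root, sub_dir))
            img_names = [n for n in img_names if n.endswith('.jpg')]
            for name in img_names:
                path_img = os.path.join(root, sub_dir, name)
                label = rmb_label[sub_dir]          # the label comes from the folder name
                data_info.append((path_img, label))
    return data_info  # list of (image path, class index) pairs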
The resulting data is a list with two elements: the first is the batch of image tensors with shape (16, 3, 32, 32), and the second is the batch of labels with shape (16,).
In short, the Sampler acts as the index provider: it yields a list of indices that determines which batch_size samples make up the current batch.
Summary
That wraps up this section, which was mainly about understanding the mechanics of PyTorch's DataLoader and Dataset modules. The code is best understood through three questions: which data to read, where to read it from, and how to read it.
The DataLoader goes through _next_data -> _next_index -> _sampler_iter -> return batch, where batch is the list of indices; in these steps the sampler produces one batch of indices, and fetch then consumes them, which answers "which data to read".
It then continues through _dataset_fetcher.fetch(index) -> self.dataset[idx] -> RMBDataset.__getitem__ -> self.data_info, which looks up the image path and label for each index -> Image.open().convert(), which reads the image data from that path and returns the image together with its label.
Inside the Dataset, the data directory path answers "where to read from", and __getitem__ answers "how to read".
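Putting the whole chain together, one iteration of the DataLoader is conceptually equivalent to the following plain-Python sketch (simplified: no worker processes, no pin_memory; the helper name one_batch is invented here):

def one_batch(dataset, sampler_iter, collate_fn):
    index = next(sampler_iter)             # which data: the batch sampler yields a list of indices
    samples = [dataset[i] for i in index]  # where / how: Dataset.__getitem__ reads each sample
    return collate_fn(samples)             # stack the samples into batch tensors

# usage sketch, assuming train_data = RMBDataset(data_dir=..., transform=...) from the course script:
# from torch.utils.data import BatchSampler, RandomSampler, default_collate
# sampler_iter = iter(BatchSampler(RandomSampler(train_data), batch_size=16, drop_last=False))
# imgs, labels = one_batch(train_data, sampler_iter, default_collate)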