PyTorch Week 2——Dataloader与Dataset

系列文章目录

PyTorch Week 1



前言

本文记录在深度之眼PyTorch基础第二周课程学习的知识

一、数据读取

1 一个人民币二分类任务

DataLoader

torch.utils.data.DataLoader(dataset,#Dataset类,决定数据从哪读取以及如何读取
							batch_size=1,#
							shuffle=False,#
							sampler=None,
							batch_sampler=None,
							num_workers=0,#多进程
							collate_fn=None,
							pin_memory=False,
							drop_last=False,#是否舍弃最后一批数据
							timeout=0,
							worker_init_fn=None,
							multiprocessing_context=None)

Dataset

class Dataset(object):#重写Dataset以用于自己的问题
	def __getitem__(self, index):#接收一个索引返回一个样本
		raise NotImplementedError
		
	def __add__(self,other):
		return ConcatDataset([self, other])

数据读取:

  1. 读那些数据?
  2. 从哪读数据?
  3. 怎么读数据?

代码调试,理解DataLoader的数据读取机制

  1. 设置断点,步入train_loader
# ============================ step 5/5 训练 ============================
train_curve = list()
valid_curve = list()

for epoch in range(MAX_EPOCH):

    loss_mean = 0.
    correct = 0.
    total = 0.

    net.train()
    for i, data in enumerate(train_loader):#设置断点,步入train_loader

2 继续步入,单进程情况下,使用_get_iterator()获取数据

    def __iter__(self) -> '_BaseDataLoaderIter':
        # When using a single worker the returned iterator should be
        # created everytime to avoid reseting its state
        # However, in the case of a multiple workers iterator
        # the iterator is only created once in the lifetime of the
        # DataLoader object so that workers can be reused
        if self.persistent_workers and self.num_workers > 0:
            if self._iterator is None:
                self._iterator = self._get_iterator()
            else:
                self._iterator._reset(self)
            return self._iterator
        else:
            return self._get_iterator()# 单进程,使用_get_iterator()获取数据
    def _get_iterator(self) -> '_BaseDataLoaderIter':
        if self.num_workers == 0:
            return _SingleProcessDataLoaderIter(self)#单进程获取数据
        else:
            self.check_worker_number_rationality()
            return _MultiProcessingDataLoaderIter(self)

3 单进程获取数据函数,首先使用_next_index()函数获取index列表,再通过_dataset_fetcher.fetch(index)获取index对应的data。

class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
    def __init__(self, loader):
        super(_SingleProcessDataLoaderIter, self).__init__(loader)
        assert self._timeout == 0
        assert self._num_workers == 0

        self._dataset_fetcher = _DatasetKind.create_fetcher(
            self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last)

    def _next_data(self):
        index = self._next_index()  # may raise StopIteration
        data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
        if self._pin_memory:
            data = _utils.pin_memory.pin_memory(data)
        return data

3 步入self._next_index()中,进入Iter类,通过_index_sampler获取index

class _BaseDataLoaderIter(object):
    def __init__(self, loader: DataLoader) -> None:
        
    def __iter__(self) -> '_BaseDataLoaderIter':
        return self

    def _reset(self, loader, first_iter=False):
        self._sampler_iter = iter(self._index_sampler)
        self._num_yielded = 0
        self._IterableDataset_len_called = loader._IterableDataset_len_called

    def _next_index(self):
        return next(self._sampler_iter) 

步入._sampler_iter,在sampler.py文件内,在这里生成了index

    def __iter__(self) -> Iterator[List[int]]:
        batch = []
        for idx in self.sampler:
            batch.append(idx)
            if len(batch) == self.batch_size:
                yield batch
                batch = []#返回的index
        if len(batch) > 0 and not self.drop_last:
            yield batch

4 self._dataset_fetcher.fetch中,fetchh函数用于获取index对应的数据,返回data

class _MapDatasetFetcher(_BaseDatasetFetcher):
    def __init__(self, dataset, auto_collation, collate_fn, drop_last):
        super(_MapDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last)

    def fetch(self, possibly_batched_index):
        if self.auto_collation:
            data = [self.dataset[idx] for idx in possibly_batched_index]#在这里调用了dataset来获取数据
        else:
            data = self.dataset[possibly_batched_index]
        return self.collate_fn(data)

5 步入dataset,在自定义的RMBDataset中,__getitem__用于根据传递进来的index读取数据

class RMBDataset(Dataset):
    def __init__(self, data_dir, transform=None):
        """
        rmb面额分类任务的Dataset
        :param data_dir: str, 数据集所在路径
        :param transform: torch.transform,数据预处理
        """
        self.label_name = {"1": 0, "100": 1}
        self.data_info = self.get_img_info(data_dir)  # data_info存储所有图片路径和标签,在DataLoader中通过index读取样本
        self.transform = transform

    def __getitem__(self, index):#按照传递进来的index的索引获取图片路径和标签
    	path_img, label = self.data_info[index] 
        img = Image.open(path_img).convert('RGB')     # 0~255根据图片路径读取图片数据

        if self.transform is not None:
            img = self.transform(img)   # 在这里做transform,转为tensor等等

        return img, label

    def __len__(self):
        return len(self.data_info)
  • 形成的data数据格式如下,列表,包含两个元素,第一个元素是图片数据,shape = (16, 3, 32, 32),第二个元素是标签,shape = (16,)
    

在这里插入图片描述

简单来说,Sampler
函数作为采样器,提供一个index列表,决定了这个batch_size数据的索引,

总结

以上就是本节内容,主要是对于pytorch的Dataloader和Dataset模块的机制认识。从数据读那些?从哪读?和怎么读?三个方面去理解代码。

Dataloader通过——>_next_data——>_next_index——>_sampler_iter——>return batch作为index列表,以上步骤通过sampler获取一个index,fetch调用index解决读那些数据的问题。

再通过——>.dataset_fetcher.fetch(index)——>self.dataset[idx]——>RMBDataset里的__getitem_——>self.data_info按照index读取图片路径列表和标签列表——>Image.open().convert()按照图片路径读取图片数据,return 图片数据 标签
dataset中,文件路径解决从哪读,__getitem__解决怎么读的问题

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值