Pytorch:数据读取机制Dataloader与Dataset

1 Dataloader与Dateset

1.1 Dataloader

  • 功能:构建可迭代的数据装载器
  • dataset:Dataset类,决定数据从哪读取以及如何读取
  • batchsize:批大小
  • num_works:是否多进程读取数据
  • shuffle:每个epoch是否乱序
  • drop_last:样本不能被batchsize整除时,是否舍弃最后一批数据

1.2 Dataset

  • 功能:Dataset是一个抽象类,需要自定义Dataset继承,并且复写__getitem__()
  • getitem():接收一个索引,返回一个样本以及标签
  • 例如:
class DogCatDataset(Dataset):
    def __init__(self, data_dir, transform=None):
        """
        分类任务的Dataset
        :param data_dir: str, 数据集所在路径
        :param transform: torch.transform,数据预处理
        """
        self.data_info = self.get_img_info(data_dir)  # data_info存储所有图片路径和标签,在DataLoader中通过index读取样本
        self.transform = transform

    def __getitem__(self, index):
        path_img, label = self.data_info[index]
        img = Image.open(path_img).convert('RGB')     # 0~255

        if self.transform is not None:
            img = self.transform(img)   # 在这里做transform,转为tensor等等

        return img, label

    def __len__(self):
        return len(self.data_info)

    @staticmethod
    def get_img_info(data_dir):
        data_info = list()
        img_names = os.listdir(data_dir)


        # 遍历图片
        for i in range(len(img_names)):
            img_name = img_names[i]
            path_img = os.path.join(data_dir, img_name)
            if "cat" in img_name:
                label = 0
            else:
                label = 1
            data_info.append((path_img, int(label)))

        return data_info

2 如何使用Dataloader与Dataset读取数据

如何使用Dataloader与Dataset读取数据,并且使用此数据进行训练。

代码示例:

norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

train_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

valid_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

# 构建MyDataset实例
train_data = RMBDataset(data_dir=train_dir, transform=train_transform)
valid_data = RMBDataset(data_dir=valid_dir, transform=valid_transform)

# 构建DataLoder
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(dataset=valid_data, batch_size=BATCH_SIZE)

# 遍历数据中每个batch
for i, data in enumerate(train_loader):
	"""
	此部分通过 遍历train_loader可以获取到所有的batch中数据
	注:如模型训练中设置了epoch,还需要在此循环外套一个epoch的循环
	"""

3 数据读取机制

3.1 数据获取机制流程

此部分通过对上段代码中的loader循环打上断点进行步进调试,观察其执行机制。

  • 第一步:for i, data in enumerate(train_loader):
  • 第二步:进入了dataloader类中的__iter__()函数,此函数主要返回一个迭代器,但是进行了判断,如何返回迭代器
    def __iter__(self) -> '_BaseDataLoaderIter':
        # When using a single worker the returned iterator should be
        # created everytime to avoid reseting its state
        # However, in the case of a multiple workers iterator
        # the iterator is only created once in the lifetime of the
        # DataLoader object so that workers can be reused
        """
        #当使用单个工作程序时,返回的迭代器应为
		#每次创建以避免重置其状态
		#然而,在多工作者迭代器的情况下
		#迭代器在对象的生命周期中仅创建一次
		#DataLoader对象,以便可以重用Worker
        """
        if self.persistent_workers and self.num_workers > 0:
            if self._iterator is None:
                self._iterator = self._get_iterator()
            else:
                self._iterator._reset(self)
            return self._iterator
        else:
            return self._get_iterator()
  • 第三步:进入了dataloader类中的_get_iterator(self)函数,通过进程数量判断,得到一个(或者多个)迭代器。
 def _get_iterator(self) -> '_BaseDataLoaderIter':
        if self.num_workers == 0:
            return _SingleProcessDataLoaderIter(self)
        else:
            self.check_worker_number_rationality()
            return _MultiProcessingDataLoaderIter(self)
  • 第四步: 进入_SingleProcessDataLoaderIter中的__init__()函数
"""
继承自_BaseDataLoaderIter的超类
"""
class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
    def __init__(self, loader):
        super(_SingleProcessDataLoaderIter, self).__init__(loader)
        assert self._timeout == 0
        assert self._num_workers == 0

        self._dataset_fetcher = _DatasetKind.create_fetcher(
            self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last)
  • 第五步:进入_BaseDataLoaderIter的的__init__()函数,初始化超类的一些基本信息
class _BaseDataLoaderIter(object):
    def __init__(self, loader: DataLoader) -> None:
        self._dataset = loader.dataset
        self._dataset_kind = loader._dataset_kind
        self._IterableDataset_len_called = loader._IterableDataset_len_called
        self._auto_collation = loader._auto_collation
        self._drop_last = loader.drop_last
        self._index_sampler = loader._index_sampler
        self._num_workers = loader.num_workers
        self._prefetch_factor = loader.prefetch_factor
        self._pin_memory = loader.pin_memory and torch.cuda.is_available()
        self._timeout = loader.timeout
        self._collate_fn = loader.collate_fn
        self._sampler_iter = iter(self._index_sampler)
        self._base_seed = torch.empty((), dtype=torch.int64).random_(generator=loader.generator).item()
        self._persistent_workers = loader.persistent_workers
        self._num_yielded = 0
        self._profile_name = "enumerate(DataLoader)#{}.__next__".format(self.__class__.__name__)
  • 第六步:在BaseDataLoaderIter类中的__init__()函数中执行到第四行,进入DataLoader类,_auto_collation()函数。这个函数返回了一个对batch_sampler判断。

sampler主要是输出每个batch需要数据集的索引index

class DataLoader(Generic[T_co]):
    def _auto_collation(self):
        return self.batch_sampler is not None
  • 第七步:在BaseDataLoaderIter类中的__init__()函数中执行到第六行,进入DataLoader类,_index_sampler()函数,这个函数主要是要生成一个sampler类了。
class DataLoader(Generic[T_co]):
    def _index_sampler(self):
        # The actual sampler used for generating indices for `_DatasetFetcher`
        # (see _utils/fetch.py) to read data at each time. This would be
        # `.batch_sampler` if in auto-collation mode, and `.sampler` otherwise.
        # We can't change `.sampler` and `.batch_sampler` attributes for BC
        # reasons.
        if self._auto_collation:
            return self.batch_sampler
        else:
            return self.sampler
  • 第八步:执行完BaseDataLoaderIter类中的__init__()函数,返回_SingleProcessDataLoaderIter,执行__init__()函数的最后一行,返回到_SingleProcessDataLoaderIter类,执行其最后一行代码,进入_DatasetKind,create_fetcher()函数

这个函数返回的fetcher的类,主要是对数据集,sampler的应用,将数据集分成N个batch,并判断drop_last的T F,是否要保留最后一个不满足batch_size大小的batch,最后把所有batch组装在一起。

class _DatasetKind(object):
    Map = 0
    Iterable = 1

    @staticmethod
    def create_fetcher(kind, dataset, auto_collation, collate_fn, drop_last):
        if kind == _DatasetKind.Map:
            return _utils.fetch._MapDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)
        else:
            return _utils.fetch._IterableDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)

  • 第九步:执行到上一步的 return MapDatasetFetcher,进入到_MapDatasetFetcher
class _MapDatasetFetcher(_BaseDatasetFetcher):
    def __init__(self, dataset, auto_collation, collate_fn, drop_last):
        super(_MapDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last)

    def fetch(self, possibly_batched_index):
        if self.auto_collation:
            data = [self.dataset[idx] for idx in possibly_batched_index]
        else:
            data = self.dataset[possibly_batched_index]
        return self.collate_fn(data)
  • 第十步:进入_BaseDatasetFetcher类,init函数
class _BaseDatasetFetcher(object):
    def __init__(self, dataset, auto_collation, collate_fn, drop_last):
        self.dataset = dataset
        self.auto_collation = auto_collation
        self.collate_fn = collate_fn
        self.drop_last = drop_last
  • 第十一步:返回到DataLoader,_get_iterator函数,(即返回到了第三步,第三步调用的类,函数都以及完成)
    def _get_iterator(self) -> '_BaseDataLoaderIter':
        if self.num_workers == 0:
            return _SingleProcessDataLoaderIter(self)
        else:
            self.check_worker_number_rationality()
            return _MultiProcessingDataLoaderIter(self)

  • 第十二步:返回到DataLoaderiter函数,同理十一步,返回到了第二步
  • 第十三步:返回到了for i, data in enumerate(train_loader):

意味着以及对数据集的读取,做好了各种初始化,sampler,dataset等参数以及准备完毕。

  • 第十四步::进入_BaseDataLoaderIternext函数
    def __next__(self) -> Any:
        with torch.autograd.profiler.record_function(self._profile_name):
            if self._sampler_iter is None:
                self._reset()
            data = self._next_data()
            self._num_yielded += 1
            if self._dataset_kind == _DatasetKind.Iterable and \
                    self._IterableDataset_len_called is not None and \
                    self._num_yielded > self._IterableDataset_len_called:
                warn_msg = ("Length of IterableDataset {} was reported to be {} (when accessing len(dataloader)), but {} "
                            "samples have been fetched. ").format(self._dataset, self._IterableDataset_len_called,
                                                                  self._num_yielded)
                if self._num_workers > 0:
                    warn_msg += ("For multiprocessing data-loading, this could be caused by not properly configuring the "
                                 "IterableDataset replica at each worker. Please see "
                                 "https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset for examples.")
                warnings.warn(warn_msg)
            return data
  • 第十五步:当执行到上一步的data = self._next_data(),进入_SingleProcessDataLoaderIter,_next_data的函数
class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
    def __init__(self, loader):
        super(_SingleProcessDataLoaderIter, self).__init__(loader)
        assert self._timeout == 0
        assert self._num_workers == 0

        self._dataset_fetcher = _DatasetKind.create_fetcher(
            self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last)
    def _next_data(self):
        index = self._next_index()  # may raise StopIteration
        data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
        if self._pin_memory:
            data = _utils.pin_memory.pin_memory(data)
        return data
  • 第十六步:进入_BaseDataLoaderIter,_next_index的函数
    def _next_index(self):
        return next(self._sampler_iter)  # may raise StopIteration
  • 第十七步::返回到_SingleProcessDataLoaderIter,_next_data函数

此时返回给了index 一组batch应该在总数据集上的索引。

在这里插入图片描述

  • 第十八步:进入_MapDatasetFetcher,fetch函数
    在这里插入图片描述
  • 第十九步:进入到dataset,__getitem__函数。

此函数就是通过索引值,在总的数据集上获取数据,fetch函数就是通过上一步返回来的索引值,遍历获取数据(调用dataset的getitem函数)。

  • 第二十步:所有索引对应的数据获取完毕
    在这里插入图片描述
  • 第二十一步:进入default_collate类

此类的作用,可以通过注释看出,就算把一组batch的数据,封装成一个tensor返回,即原本一张图片的数据是三通道分别代表 w h c ,那么一组batch封装成一个tensor就是 b w h c

def default_collate(batch):
    r"""Puts each data field into a tensor with outer dimension batch size"""

    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, torch.Tensor):
        out = None
        if torch.utils.data.get_worker_info() is not None:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum([x.numel() for x in batch])
            storage = elem.storage()._new_shared(numel)
            out = elem.new(storage)
        return torch.stack(batch, 0, out=out)
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
            # array of string classes and object
            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                raise TypeError(default_collate_err_msg_format.format(elem.dtype))

            return default_collate([torch.as_tensor(b) for b in batch])
        elif elem.shape == ():  # scalars
            return torch.as_tensor(batch)
    elif isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(elem, int):
        return torch.tensor(batch)
    elif isinstance(elem, string_classes):
        return batch
    elif isinstance(elem, collections.abc.Mapping):
        return {key: default_collate([d[key] for d in batch]) for key in elem}
    elif isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
        return elem_type(*(default_collate(samples) for samples in zip(*batch)))
    elif isinstance(elem, collections.abc.Sequence):
        # check to make sure that the elements in batch have consistent size
        it = iter(batch)
        elem_size = len(next(it))
        if not all(len(elem) == elem_size for elem in it):
            raise RuntimeError('each element in list of batch should be of equal size')
        transposed = zip(*batch)
        return [default_collate(samples) for samples in transposed]

    raise TypeError(default_collate_err_msg_format.format(elem_type))
  • 第二十二步:返回到_SingleProcessDataLoaderIter,_next_data函数,即返回到第十七步,代表着这个batch的数据获取完毕
    在这里插入图片描述
  • 第二十三步:返回到_BaseDataLoaderIternext函数,即返回到第十四步,代表着将获取到的数据返回到他,最后return data。至此一个batch的数据获取完毕!!!

3.2 数据获取机制总结:

  • 读那些数据:由sampler输出的index决定
  • 从哪儿读数据:由dataset中的data_dir决定
  • 怎么读数据:由dataset中的getitem决定

Sequence diagram

DataLoader
DateLoaderIter
Sampler
Index
DataSetFetcher
Dateset
getitem
Img,Label
collate_fn
Batch_Data
  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值