Dataset+DataLoader加载数据流程代码详解

引言

这篇文章是上一篇文章的一个补充,主要是从数据存储的角度出发去探索整个pytorch的数据加载方式,最终会追朔到数据文件的读写。本文以MNIST数据集举例,从源代码分析了解整个流程。

第一步定义数据集

train_data = torchvision.datasets.MNIST(
   root='./MINIST',  # 数据集的位置
   train=True,  # 如果为True则为训练集,如果为False则为测试集
   transform=torchvision.transforms.ToTensor(),  # 将图片转化成取值[0,1]的Tensor用于网络处理
   download=False
)

上面代码是初始化MNIST数据集的代码,MNIST是pytorch的torch.util.data.Dataset的实现。

class MNIST(VisionDataset):
class VisionDataset(data.Dataset):

可以看到它实现了Dataset中的中的__getitem__()接口,具体实现如下:

def __init__(
   self,
   root: str,
   train: bool = True,
   transform: Optional[Callable] = None,
   target_transform: Optional[Callable] = None,
   download: bool = False,
) -> None:
   super().__init__(root, transform=transform, target_transform=target_transform)
   self.train = train  # training set or test set
   if self._check_legacy_exist():
       self.data, self.targets = self._load_legacy_data()
       return
   if download:
       self.download()
   if not self._check_exists():
       raise RuntimeError("Dataset not found. You can use download=True to download it")
   self.data, self.targets = self._load_data()
def __getitem__(self, index: int) -> Tuple[Any, Any]:
   img, target = self.data[index], int(self.targets[index])
   img = Image.fromarray(img.numpy(), mode="L")
   if self.transform is not None:
       img = self.transform(img)
   if self.target_transform is not None:
       target = self.target_transform(target)
   return img, target

这个接口的实现非常重要,这是pytorch暴露出来的提供给程序员自定义文件加载到内存的接口。最后我们会发现数据集中的每一个{x:y}数据对都是从这个方法中加载的。

第二步通过DataLoader加载数据集

train_loader=Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)

将之前构建的dataset传入,DataLoader实现了__iter__方法产生迭代器,对应的迭代器_BaseDataLoaderIter实现了__next__()方法。

for step, (b_x, b_y) in enumerate(train_loader):

因此在for循环提取数据时,会调用dataLoader的__iter__方法创建DataLoader的迭代器,随后在每一次循环调用__next__方法。
如果对python中的迭代器和生成器不了解的话,可以参考这篇文章https://zhuanlan.zhihu.com/p/341439647
BaseDataLoaderIter中的__next_()方法会调用_next_data()方法,对于单进程和多进程的_next_data()方法实现的方式不同。多进程的实现较为复杂我将专门出一期文章进行讨论。
一下是单进程的_next_data()函数

def _next_data(self):
   index = self._next_index()  # may raise StopIteration
   data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
   if self._pin_memory:
       data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)
   return data

_next_data()函数,通过传入index,调用fetch函数获得数据,对于fetch函数,map类型数据集和iter类型数据集实现方式不同。

map类型数据集
def fetch(self, possibly_batched_index):
   if self.auto_collation:
       if hasattr(self.dataset, "__getitems__") and self.dataset.__getitems__:
           data = self.dataset.__getitems__(possibly_batched_index)
       else:
           data = [self.dataset[idx] for idx in possibly_batched_index]
   else:
       data = self.dataset[possibly_batched_index]
   return self.collate_fn(data)

由上述代码可以看到最终调用的是dataset的__getitem__()函数,getitem()函数实现了利用下标索引来访问数据的作用,以MNIST数据集举例,MNIST数据集在初始化时,将数据加载到内存。

def __init__(
   self,
   root: str,
   train: bool = True,
   transform: Optional[Callable] = None,
   target_transform: Optional[Callable] = None,
   download: bool = False,
) -> None:
   super().__init__(root, transform=transform, target_transform=target_transform)
   self.train = train  # training set or test set
   if self._check_legacy_exist():
       self.data, self.targets = self._load_legacy_data()
       return
   if download:
       self.download()
   if not self._check_exists():
       raise RuntimeError("Dataset not found. You can use download=True to download it")
   self.data, self.targets = self._load_data()
  
def _load_data(self):
   image_file = f"{'train' if self.train else 't10k'}-images-idx3-ubyte"
   data = read_image_file(os.path.join(self.raw_folder, image_file))

   label_file = f"{'train' if self.train else 't10k'}-labels-idx1-ubyte"
   targets = read_label_file(os.path.join(self.raw_folder, label_file))

   return data, targets

因此对于MNIST数据集,getitem(index)函数实现了对内存中数据访问的方式。不难理解,这个接口的作用就是定义数据的访问方式,或者说整个dataset就是在定义数据访问方式。
_load_data()函数,通过文件名使用read_sn3_pascalvincent_tensor()用于读取 Pascal Vincent 编码的序列化神经网络(SN3)格式的张量数据。

def read_label_file(path: str) -> torch.Tensor:
    x = read_sn3_pascalvincent_tensor(path, strict=False)
    if x.dtype != torch.uint8:
        raise TypeError(f"x should be of dtype torch.uint8 instead of {x.dtype}")
    if x.ndimension() != 1:
        raise ValueError(f"x should have 1 dimension instead of {x.ndimension()}")
    return x.long()


def read_image_file(path: str) -> torch.Tensor:
    x = read_sn3_pascalvincent_tensor(path, strict=False)
    if x.dtype != torch.uint8:
        raise TypeError(f"x should be of dtype torch.uint8 instead of {x.dtype}")
    if x.ndimension() != 3:
        raise ValueError(f"x should have 3 dimension instead of {x.ndimension()}")
    return x

tip:从数据存储的角度分析,这种通过直接文件名的文件访问方式,是很难在数据加载效率方面进行优化改进的,最多在数据容灾层面上做优化。

当在初始化过程中将所有数据解码加载到内存过后,就需要利用dataLoader调用__getitem__(index)函数取出数据。index由dataLoader的迭代器中的_next_index()方法获得。

def _next_index(self):
   return next(self._sampler_iter)  # may raise StopIteration

_sampler_iter是Sampler(数据分配对象的迭代器),所以数据的索引是通过数据分配器来获得的。在MNIST数据集中采用的是RandomSampler,所以它的下标是范围内随机数。

self._sampler_iter = iter(self._index_sampler)
Iter类型数据集
class _IterableDatasetFetcher(_BaseDatasetFetcher):
   def __init__(self, dataset, auto_collation, collate_fn, drop_last):
       super().__init__(dataset, auto_collation, collate_fn, drop_last)
       self.dataset_iter = iter(dataset)
       self.ended = False
   def fetch(self, possibly_batched_index):
       if self.ended:
           raise StopIteration
       if self.auto_collation:
           data = []
           for _ in possibly_batched_index:
               try:
                   data.append(next(self.dataset_iter))
               except StopIteration:
                   self.ended = True
                   break
           if len(data) == 0 or (
               self.drop_last and len(data) < len(possibly_batched_index)
           ):
               raise StopIteration
       else:
           data = next(self.dataset_iter)
       return self.collate_fn(data)

迭代器的创建是在dataLoader迭代器创建的时候创建。以一个自定义的数据集举例:

class MyIterableDataset(IterableDataset):
    def __init__(self, file_path):
        self.file_path = file_path
    def __iter__(self):
        with open(self.file_path, 'r') as file_obj:
            for line in file_obj: # 更多操作在这里完成
                line_data = line.strip('\n').split(',')
                yield line_data

每次next 会运行到yield 暂停点,这样一行一行获取数据。
data由迭代器获得,通过next方法不断迭代运行到暂停点以获取数据。

总结

整体的数据访问流程如下:
For循环->loader.iter–> self._get_iterator() --> class _SingleProcessDataLoaderIter --> class _BaseDataLoaderIter --> next() --> self._next_data() --> self.next_index() -->next(self.sampler_iter) 即 next(iter(self.index_sampler)) --> 获得 index --> self.dataset_fetcher.fetch(index) -->dataset.getitem () or dataset.iter().
其中dataset.init()、Dataset.getitem()、dataset.iter()
由AI开发人员定义,对于MNIST手写数字识别任务,通过在__init
()函数中利用文件名将文件直接加载到内存,然后dataLoader调用__getitem
()访问数据。

对于多进程的数据处理流程,索引index的获取需要通过多线程对索引队列的操作实现,在获取之后,仍然是通过self.dataset_fetcher.fetch(index) ,调用__getitem_()获取一个数据,放入到结果队列中。本质上数据的加载过程和单进程没有区别。

  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值