引言
这篇文章是上一篇文章的一个补充,主要是从数据存储的角度出发去探索整个pytorch的数据加载方式,最终会追朔到数据文件的读写。本文以MNIST数据集举例,从源代码分析了解整个流程。
第一步定义数据集
train_data = torchvision.datasets.MNIST(
root='./MINIST', # 数据集的位置
train=True, # 如果为True则为训练集,如果为False则为测试集
transform=torchvision.transforms.ToTensor(), # 将图片转化成取值[0,1]的Tensor用于网络处理
download=False
)
上面代码是初始化MNIST数据集的代码,MNIST是pytorch的torch.util.data.Dataset的实现。
class MNIST(VisionDataset):
class VisionDataset(data.Dataset):
可以看到它实现了Dataset中的中的__getitem__()接口,具体实现如下:
def __init__(
self,
root: str,
train: bool = True,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False,
) -> None:
super().__init__(root, transform=transform, target_transform=target_transform)
self.train = train # training set or test set
if self._check_legacy_exist():
self.data, self.targets = self._load_legacy_data()
return
if download:
self.download()
if not self._check_exists():
raise RuntimeError("Dataset not found. You can use download=True to download it")
self.data, self.targets = self._load_data()
def __getitem__(self, index: int) -> Tuple[Any, Any]:
img, target = self.data[index], int(self.targets[index])
img = Image.fromarray(img.numpy(), mode="L")
if self.transform is not None:
img = self.transform(img)
if self.target_transform is not None:
target = self.target_transform(target)
return img, target
这个接口的实现非常重要,这是pytorch暴露出来的提供给程序员自定义文件加载到内存的接口。最后我们会发现数据集中的每一个{x:y}数据对都是从这个方法中加载的。
第二步通过DataLoader加载数据集
train_loader=Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
将之前构建的dataset传入,DataLoader实现了__iter__方法产生迭代器,对应的迭代器_BaseDataLoaderIter实现了__next__()方法。
for step, (b_x, b_y) in enumerate(train_loader):
因此在for循环提取数据时,会调用dataLoader的__iter__方法创建DataLoader的迭代器,随后在每一次循环调用__next__方法。
如果对python中的迭代器和生成器不了解的话,可以参考这篇文章https://zhuanlan.zhihu.com/p/341439647
BaseDataLoaderIter中的__next_()方法会调用_next_data()方法,对于单进程和多进程的_next_data()方法实现的方式不同。多进程的实现较为复杂我将专门出一期文章进行讨论。
一下是单进程的_next_data()函数
def _next_data(self):
index = self._next_index() # may raise StopIteration
data = self._dataset_fetcher.fetch(index) # may raise StopIteration
if self._pin_memory:
data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)
return data
_next_data()函数,通过传入index,调用fetch函数获得数据,对于fetch函数,map类型数据集和iter类型数据集实现方式不同。
map类型数据集
def fetch(self, possibly_batched_index):
if self.auto_collation:
if hasattr(self.dataset, "__getitems__") and self.dataset.__getitems__:
data = self.dataset.__getitems__(possibly_batched_index)
else:
data = [self.dataset[idx] for idx in possibly_batched_index]
else:
data = self.dataset[possibly_batched_index]
return self.collate_fn(data)
由上述代码可以看到最终调用的是dataset的__getitem__()函数,getitem()函数实现了利用下标索引来访问数据的作用,以MNIST数据集举例,MNIST数据集在初始化时,将数据加载到内存。
def __init__(
self,
root: str,
train: bool = True,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False,
) -> None:
super().__init__(root, transform=transform, target_transform=target_transform)
self.train = train # training set or test set
if self._check_legacy_exist():
self.data, self.targets = self._load_legacy_data()
return
if download:
self.download()
if not self._check_exists():
raise RuntimeError("Dataset not found. You can use download=True to download it")
self.data, self.targets = self._load_data()
def _load_data(self):
image_file = f"{'train' if self.train else 't10k'}-images-idx3-ubyte"
data = read_image_file(os.path.join(self.raw_folder, image_file))
label_file = f"{'train' if self.train else 't10k'}-labels-idx1-ubyte"
targets = read_label_file(os.path.join(self.raw_folder, label_file))
return data, targets
因此对于MNIST数据集,getitem(index)函数实现了对内存中数据访问的方式。不难理解,这个接口的作用就是定义数据的访问方式,或者说整个dataset就是在定义数据访问方式。
_load_data()函数,通过文件名使用read_sn3_pascalvincent_tensor()用于读取 Pascal Vincent 编码的序列化神经网络(SN3)格式的张量数据。
def read_label_file(path: str) -> torch.Tensor:
x = read_sn3_pascalvincent_tensor(path, strict=False)
if x.dtype != torch.uint8:
raise TypeError(f"x should be of dtype torch.uint8 instead of {x.dtype}")
if x.ndimension() != 1:
raise ValueError(f"x should have 1 dimension instead of {x.ndimension()}")
return x.long()
def read_image_file(path: str) -> torch.Tensor:
x = read_sn3_pascalvincent_tensor(path, strict=False)
if x.dtype != torch.uint8:
raise TypeError(f"x should be of dtype torch.uint8 instead of {x.dtype}")
if x.ndimension() != 3:
raise ValueError(f"x should have 3 dimension instead of {x.ndimension()}")
return x
tip:从数据存储的角度分析,这种通过直接文件名的文件访问方式,是很难在数据加载效率方面进行优化改进的,最多在数据容灾层面上做优化。
当在初始化过程中将所有数据解码加载到内存过后,就需要利用dataLoader调用__getitem__(index)函数取出数据。index由dataLoader的迭代器中的_next_index()方法获得。
def _next_index(self):
return next(self._sampler_iter) # may raise StopIteration
_sampler_iter是Sampler(数据分配对象的迭代器),所以数据的索引是通过数据分配器来获得的。在MNIST数据集中采用的是RandomSampler,所以它的下标是范围内随机数。
self._sampler_iter = iter(self._index_sampler)
Iter类型数据集
class _IterableDatasetFetcher(_BaseDatasetFetcher):
def __init__(self, dataset, auto_collation, collate_fn, drop_last):
super().__init__(dataset, auto_collation, collate_fn, drop_last)
self.dataset_iter = iter(dataset)
self.ended = False
def fetch(self, possibly_batched_index):
if self.ended:
raise StopIteration
if self.auto_collation:
data = []
for _ in possibly_batched_index:
try:
data.append(next(self.dataset_iter))
except StopIteration:
self.ended = True
break
if len(data) == 0 or (
self.drop_last and len(data) < len(possibly_batched_index)
):
raise StopIteration
else:
data = next(self.dataset_iter)
return self.collate_fn(data)
迭代器的创建是在dataLoader迭代器创建的时候创建。以一个自定义的数据集举例:
class MyIterableDataset(IterableDataset):
def __init__(self, file_path):
self.file_path = file_path
def __iter__(self):
with open(self.file_path, 'r') as file_obj:
for line in file_obj: # 更多操作在这里完成
line_data = line.strip('\n').split(',')
yield line_data
每次next 会运行到yield 暂停点,这样一行一行获取数据。
data由迭代器获得,通过next方法不断迭代运行到暂停点以获取数据。
总结
整体的数据访问流程如下:
For循环->loader.iter–> self._get_iterator() --> class _SingleProcessDataLoaderIter --> class _BaseDataLoaderIter --> next() --> self._next_data() --> self.next_index() -->next(self.sampler_iter) 即 next(iter(self.index_sampler)) --> 获得 index --> self.dataset_fetcher.fetch(index) -->dataset.getitem () or dataset.iter().
其中dataset.init()、Dataset.getitem()、dataset.iter()
由AI开发人员定义,对于MNIST手写数字识别任务,通过在__init()函数中利用文件名将文件直接加载到内存,然后dataLoader调用__getitem()访问数据。
对于多进程的数据处理流程,索引index的获取需要通过多线程对索引队列的操作实现,在获取之后,仍然是通过self.dataset_fetcher.fetch(index) ,调用__getitem_()获取一个数据,放入到结果队列中。本质上数据的加载过程和单进程没有区别。