1 DataLoader and Dataset
1.1 DataLoader
- Function: builds an iterable data loader over a dataset (a minimal construction sketch follows this list)
- dataset: a Dataset instance that decides where the data is read from and how it is read
- batch_size: the batch size
- num_workers: the number of subprocesses used to read data (0 means the data is loaded in the main process)
- shuffle: whether the data is reshuffled at every epoch
- drop_last: when the number of samples is not divisible by batch_size, whether to drop the last incomplete batch
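A minimal sketch of how these arguments fit together; train_data stands for any map-style Dataset, and the argument values are illustrative:

from torch.utils.data import DataLoader

train_loader = DataLoader(
    dataset=train_data,  # the Dataset that knows where and how to read samples
    batch_size=16,       # samples per batch
    num_workers=2,       # subprocesses for loading (0 = load in the main process)
    shuffle=True,        # reshuffle indices at every epoch
    drop_last=True,      # drop the final incomplete batch
)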
1.2 Dataset
- Function: Dataset is an abstract class; a custom Dataset must subclass it and override __getitem__()
- __getitem__(): receives an index and returns the corresponding sample and its label
- For example:
import os
from PIL import Image
from torch.utils.data import Dataset

class DogCatDataset(Dataset):
    def __init__(self, data_dir, transform=None):
        """
        Dataset for a classification task
        :param data_dir: str, path to the dataset
        :param transform: torchvision transforms, the preprocessing pipeline
        """
        # data_info stores every image path and label;
        # the DataLoader reads samples from it by index
        self.data_info = self.get_img_info(data_dir)
        self.transform = transform

    def __getitem__(self, index):
        path_img, label = self.data_info[index]
        img = Image.open(path_img).convert('RGB')  # pixel values in 0~255
        if self.transform is not None:
            img = self.transform(img)  # apply the transforms here: conversion to tensor, etc.
        return img, label

    def __len__(self):
        return len(self.data_info)

    @staticmethod
    def get_img_info(data_dir):
        data_info = list()
        img_names = os.listdir(data_dir)
        # iterate over the images, labelling each one by its filename
        for img_name in img_names:
            path_img = os.path.join(data_dir, img_name)
            label = 0 if "cat" in img_name else 1
            data_info.append((path_img, int(label)))
        return data_info
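A quick sanity check of the class above; the directory path is hypothetical and should contain files whose names include "cat" or "dog":

dataset = DogCatDataset(data_dir="path/to/train")  # hypothetical directory
img, label = dataset[0]  # triggers __getitem__(0)
print(len(dataset), label)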
2 Reading data with DataLoader and Dataset
This section shows how to read data with DataLoader and Dataset and use that data for training.
Code example:
from torch.utils.data import DataLoader
from torchvision import transforms

# RMBDataset, train_dir, valid_dir and BATCH_SIZE are defined elsewhere in the project
norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

train_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])
valid_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

# build the Dataset instances
train_data = RMBDataset(data_dir=train_dir, transform=train_transform)
valid_data = RMBDataset(data_dir=valid_dir, transform=valid_transform)

# build the DataLoaders
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(dataset=valid_data, batch_size=BATCH_SIZE)

# iterate over every batch in the data
for i, data in enumerate(train_loader):
    # Iterating over train_loader yields the data of every batch.
    # Note: if training runs for several epochs, wrap an outer epoch loop
    # around this one (see the sketch below).
    inputs, labels = data  # unpack one batch
3 The data-reading mechanism
3.1 Flow of the data-fetching mechanism
This part sets a breakpoint on the loader loop in the code above and steps through it in the debugger to observe the execution mechanism.
- Step 1: for i, data in enumerate(train_loader):
- Step 2: we enter the DataLoader class's __iter__() function. It returns an iterator, after first checking how that iterator should be created:
def __iter__(self) -> '_BaseDataLoaderIter':
    # When using a single worker the returned iterator should be
    # created everytime to avoid reseting its state
    # However, in the case of a multiple workers iterator
    # the iterator is only created once in the lifetime of the
    # DataLoader object so that workers can be reused
    if self.persistent_workers and self.num_workers > 0:
        if self._iterator is None:
            self._iterator = self._get_iterator()
        else:
            self._iterator._reset(self)
        return self._iterator
    else:
        return self._get_iterator()
- Step 3: we enter the DataLoader class's _get_iterator(self) function, which checks the number of worker processes and returns either a single-process or a multi-process iterator:
def _get_iterator(self) -> '_BaseDataLoaderIter':
    if self.num_workers == 0:
        return _SingleProcessDataLoaderIter(self)
    else:
        self.check_worker_number_rationality()
        return _MultiProcessingDataLoaderIter(self)
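This branch is easy to observe from outside; assuming the train_loader from section 2, which uses the default num_workers=0:

it = iter(train_loader)   # DataLoader.__iter__() -> _get_iterator()
print(type(it).__name__)  # _SingleProcessDataLoaderIter, since num_workers == 0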
- Step 4: we enter the __init__() function of the _SingleProcessDataLoaderIter class, a subclass of _BaseDataLoaderIter:
class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
    def __init__(self, loader):
        super(_SingleProcessDataLoaderIter, self).__init__(loader)
        assert self._timeout == 0
        assert self._num_workers == 0
        self._dataset_fetcher = _DatasetKind.create_fetcher(
            self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last)
- Step 5: we enter the __init__() function of the _BaseDataLoaderIter class, which initializes the base iterator's bookkeeping:
class _BaseDataLoaderIter(object):
    def __init__(self, loader: DataLoader) -> None:
        self._dataset = loader.dataset
        self._dataset_kind = loader._dataset_kind
        self._IterableDataset_len_called = loader._IterableDataset_len_called
        self._auto_collation = loader._auto_collation
        self._drop_last = loader.drop_last
        self._index_sampler = loader._index_sampler
        self._num_workers = loader.num_workers
        self._prefetch_factor = loader.prefetch_factor
        self._pin_memory = loader.pin_memory and torch.cuda.is_available()
        self._timeout = loader.timeout
        self._collate_fn = loader.collate_fn
        self._sampler_iter = iter(self._index_sampler)
        self._base_seed = torch.empty((), dtype=torch.int64).random_(generator=loader.generator).item()
        self._persistent_workers = loader.persistent_workers
        self._num_yielded = 0
        self._profile_name = "enumerate(DataLoader)#{}.__next__".format(self.__class__.__name__)
- Step 6: on the fourth line of _BaseDataLoaderIter.__init__() we enter the DataLoader class's _auto_collation property, which simply reports whether a batch_sampler is set. The sampler's job is to produce the dataset indices that each batch needs.
class DataLoader(Generic[T_co]):
    @property
    def _auto_collation(self):
        return self.batch_sampler is not None
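Whether a batch_sampler exists depends on how the DataLoader was built. A small check, assuming PyTorch >= 1.2 (a plain list acts as a map-style dataset here):

from torch.utils.data import DataLoader

loader = DataLoader(list(range(10)), batch_size=4)     # a BatchSampler is created internally
print(loader.batch_sampler is not None)                # True  -> auto-collation mode
loader = DataLoader(list(range(10)), batch_size=None)  # automatic batching disabled
print(loader.batch_sampler is not None)                # False -> samples are yielded one by one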
- Step 7: on the sixth line of _BaseDataLoaderIter.__init__() we enter the DataLoader class's _index_sampler property, which picks the sampler that will generate the indices:
class DataLoader(Generic[T_co]):
    @property
    def _index_sampler(self):
        # The actual sampler used for generating indices for `_DatasetFetcher`
        # (see _utils/fetch.py) to read data at each time. This would be
        # `.batch_sampler` if in auto-collation mode, and `.sampler` otherwise.
        # We can't change `.sampler` and `.batch_sampler` attributes for BC
        # reasons.
        if self._auto_collation:
            return self.batch_sampler
        else:
            return self.sampler
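What the index sampler yields in auto-collation mode can be reproduced in isolation; a sketch with illustrative sizes:

from torch.utils.data import RandomSampler, BatchSampler

sampler = RandomSampler(range(10))  # yields single shuffled indices
batch_sampler = BatchSampler(sampler, batch_size=4, drop_last=False)
print(list(batch_sampler))  # e.g. [[3, 7, 1, 9], [0, 5, 2, 8], [6, 4]]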
- Step 8: _BaseDataLoaderIter.__init__() finishes and control returns to _SingleProcessDataLoaderIter.__init__(); its last line takes us into the _DatasetKind class's create_fetcher() function.
create_fetcher() returns a fetcher object. Given the indices the sampler produces for one batch, the fetcher reads the corresponding samples from the dataset and assembles them into a batch; drop_last decides whether a final batch smaller than batch_size is kept or discarded.
class _DatasetKind(object):
    Map = 0
    Iterable = 1

    @staticmethod
    def create_fetcher(kind, dataset, auto_collation, collate_fn, drop_last):
        if kind == _DatasetKind.Map:
            return _utils.fetch._MapDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)
        else:
            return _utils.fetch._IterableDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)
- Step 9: when the `return _utils.fetch._MapDatasetFetcher(...)` above executes, we enter the _MapDatasetFetcher class:
class _MapDatasetFetcher(_BaseDatasetFetcher):
    def __init__(self, dataset, auto_collation, collate_fn, drop_last):
        super(_MapDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last)

    def fetch(self, possibly_batched_index):
        if self.auto_collation:
            data = [self.dataset[idx] for idx in possibly_batched_index]
        else:
            data = self.dataset[possibly_batched_index]
        return self.collate_fn(data)
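Conceptually, what fetch() does for one batch in auto-collation mode boils down to three lines; the indices are illustrative, and dataset / collate_fn stand for the fetcher's attributes:

indices = [3, 7, 1, 9]                       # one batch of indices from the batch sampler
samples = [dataset[idx] for idx in indices]  # one __getitem__ call per index
batch = collate_fn(samples)                  # default_collate unless a custom collate_fn is given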
- Step 10: we enter the __init__() function of the _BaseDatasetFetcher class:
class _BaseDatasetFetcher(object):
    def __init__(self, dataset, auto_collation, collate_fn, drop_last):
        self.dataset = dataset
        self.auto_collation = auto_collation
        self.collate_fn = collate_fn
        self.drop_last = drop_last
- Step 11: we return to the DataLoader class's _get_iterator() function (i.e. back to Step 3; the classes and functions it called have all completed):
def _get_iterator(self) -> '_BaseDataLoaderIter':
    if self.num_workers == 0:
        return _SingleProcessDataLoaderIter(self)
    else:
        self.check_worker_number_rationality()
        return _MultiProcessingDataLoaderIter(self)
- Step 12: we return to the DataLoader class's __iter__() function, i.e. back to Step 2, just as in Step 11.
- Step 13: we are back at for i, data in enumerate(train_loader):
This means all the initialization needed for reading the dataset is done: the sampler, the dataset and the other parameters are ready.
- Step 14: we enter the _BaseDataLoaderIter class's __next__() function:
def __next__(self) -> Any:
    with torch.autograd.profiler.record_function(self._profile_name):
        if self._sampler_iter is None:
            self._reset()
        data = self._next_data()
        self._num_yielded += 1
        if self._dataset_kind == _DatasetKind.Iterable and \
                self._IterableDataset_len_called is not None and \
                self._num_yielded > self._IterableDataset_len_called:
            warn_msg = ("Length of IterableDataset {} was reported to be {} (when accessing len(dataloader)), but {} "
                        "samples have been fetched. ").format(self._dataset, self._IterableDataset_len_called,
                                                              self._num_yielded)
            if self._num_workers > 0:
                warn_msg += ("For multiprocessing data-loading, this could be caused by not properly configuring the "
                             "IterableDataset replica at each worker. Please see "
                             "https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset for examples.")
            warnings.warn(warn_msg)
        return data
- Step 15: when `data = self._next_data()` in the previous step executes, we enter the _SingleProcessDataLoaderIter class's _next_data() function:
class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
    def __init__(self, loader):
        super(_SingleProcessDataLoaderIter, self).__init__(loader)
        assert self._timeout == 0
        assert self._num_workers == 0
        self._dataset_fetcher = _DatasetKind.create_fetcher(
            self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last)

    def _next_data(self):
        index = self._next_index()  # may raise StopIteration
        data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
        if self._pin_memory:
            data = _utils.pin_memory.pin_memory(data)
        return data
- Step 16: we enter the _BaseDataLoaderIter class's _next_index() function:
def _next_index(self):
    return next(self._sampler_iter)  # may raise StopIteration
- Step 17: we return to the _SingleProcessDataLoaderIter class's _next_data() function.
At this point index holds the positions, within the full dataset, of the samples that make up one batch.
- Step 18: we enter the _MapDatasetFetcher class's fetch() function.
- Step 19: we enter the Dataset's __getitem__() function.
This function fetches a single sample from the full dataset by its index; fetch() iterates over the indices returned in the previous step and calls the Dataset's __getitem__() once per index.
- Step 20: the data for all the indices has been fetched.
- Step 21: we enter the default_collate function.
Its job, as the docstring says, is to pack the data of one batch into a single tensor: after the transforms each image is a C x H x W tensor, and stacking a batch of them yields a B x C x H x W tensor.
def default_collate(batch):
    r"""Puts each data field into a tensor with outer dimension batch size"""
    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, torch.Tensor):
        out = None
        if torch.utils.data.get_worker_info() is not None:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum([x.numel() for x in batch])
            storage = elem.storage()._new_shared(numel)
            out = elem.new(storage)
        return torch.stack(batch, 0, out=out)
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
            # array of string classes and object
            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                raise TypeError(default_collate_err_msg_format.format(elem.dtype))
            return default_collate([torch.as_tensor(b) for b in batch])
        elif elem.shape == ():  # scalars
            return torch.as_tensor(batch)
    elif isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(elem, int):
        return torch.tensor(batch)
    elif isinstance(elem, string_classes):
        return batch
    elif isinstance(elem, collections.abc.Mapping):
        return {key: default_collate([d[key] for d in batch]) for key in elem}
    elif isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
        return elem_type(*(default_collate(samples) for samples in zip(*batch)))
    elif isinstance(elem, collections.abc.Sequence):
        # check to make sure that the elements in batch have consistent size
        it = iter(batch)
        elem_size = len(next(it))
        if not all(len(elem) == elem_size for elem in it):
            raise RuntimeError('each element in list of batch should be of equal size')
        transposed = zip(*batch)
        return [default_collate(samples) for samples in transposed]
    raise TypeError(default_collate_err_msg_format.format(elem_type))
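A quick illustration of this stacking behavior; the import path below is the private location used by the PyTorch version quoted here and may differ in newer releases:

import torch
from torch.utils.data._utils.collate import default_collate

batch = [(torch.rand(3, 32, 32), 0), (torch.rand(3, 32, 32), 1)]  # two (image, label) samples
imgs, labels = default_collate(batch)
print(imgs.shape)  # torch.Size([2, 3, 32, 32])
print(labels)      # tensor([0, 1])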
- Step 22: we return to the _SingleProcessDataLoaderIter class's _next_data() function, i.e. back to Step 17; the data of this batch has now been fetched.
- Step 23: we return to the _BaseDataLoaderIter class's __next__() function, i.e. back to Step 14, which hands the fetched data back with the final return data. At this point one batch of data has been fully obtained!
3.2 Summary of the data-fetching mechanism
- Which data to read: decided by the indices the sampler outputs
- Where to read the data from: decided by data_dir in the Dataset
- How to read the data: decided by __getitem__() in the Dataset (the sketch below ties the three together)
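A closing sketch that makes this division of labour explicit, reusing the DogCatDataset class from section 1.2; the directory path is hypothetical:

from torch.utils.data import DataLoader
from torchvision import transforms

dataset = DogCatDataset(
    data_dir="path/to/train",  # where to read: data_dir
    transform=transforms.Compose([transforms.Resize((32, 32)),
                                  transforms.ToTensor()]))  # how to read: __getitem__ + transform
loader = DataLoader(dataset, batch_size=4, shuffle=True)    # which samples: the sampler's indices

it = iter(loader)        # __iter__()  -> _SingleProcessDataLoaderIter
imgs, labels = next(it)  # __next__()  -> _next_index() -> fetch() -> default_collate()
print(imgs.shape)        # torch.Size([4, 3, 32, 32])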