torch.utils.data.DataLoader
功能
数据加载:DataLoader 会遍历整个数据集,将数据从原始的数据源(如文件、数据库或内存中的数据)加载到内存中。这一步骤通常涉及到读取图像、文本或其他类型的数据,并将其转换为 PyTorch 可以处理的格式。
数据预处理:在数据被加载到内存之后,DataLoader 通常会对数据进行预处理。这可能包括标准化、归一化、缩放、裁剪、颜色通道调整等操作,以确保数据在送入模型之前满足模型的要求。
批处理:在预处理完成后,DataLoader 会将数据组织成批次(batch)。每个批次包含多个数据样本,这些样本通常会在同一个训练循环中一起处理。批处理的大小是由 batch_size 参数控制的。
数据扩充:在某些情况下,为了增加训练数据的多样性,DataLoader 可能会在加载数据时应用数据扩充技术,如随机旋转、翻转、缩放等。
多线程处理:如果 DataLoader 的 num_workers 参数设置大于 0,那么数据加载过程将在多个线程中并行进行,这样可以加快数据加载的速度。但是,需要注意的是,在 Windows 系统上,由于多线程的实现与 Unix-like 系统不同,可能会遇到死锁或其他运行时错误。
内存管理:DataLoader 还负责管理内存使用情况,确保数据在训练过程中有效地加载和释放,以避免内存不足的问题。
迭代器生成:最终,DataLoader 生成一个迭代器,该迭代器可以在训练循环中迭代,每次迭代提供一个新的批次数据。
参数
1. dataset (Dataset)
Map-style:torch.utils.data.Dataset:子类必须重写__getitem__(),选择性重写__len__(),dataset[idx]可以得到idx-th个元素和对应标签。
Iterable-style:torch.utils.data.IterableDataset:适用流数据,适用于随机读取成本高甚至不可能的情况,以及批处理大小取决于所获取的数据的情况。子类必须重写__iter__(),希望做到不同进程配置不同副本,比如:
def __iter__(self):
worker_info = torch.utils.data.get_worker_info()
if worker_info is None: # 单进程返回完整迭代器
iter_start = self.start
iter_end = self.end
else: # 多进程变量per_worker通过将数据集的范围(self.end - self.start)除以工作进程的数量(worker_info.num_workers)来计算,结果使用math.ceil()向上取整。
per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
worker_id = worker_info.id
iter_start = self.start + worker_id * per_worker
iter_end = min(iter_start + per_worker, self.end)
return iter(range(iter_start, iter_end))
init()也都得自己写吧,才能定义__getitem__()等吧。
还有
torch.utils.data.TensorDataset(*tensors):每个tensor一条数据,照第0个维度索引
input_tensor = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=torch.float32)
output_tensor = torch.tensor([0, 1, 0], dtype=torch.long)
dataset = torch.utils.data.TensorDataset(input_tensor, output_tensor) # 第0个维度要相同
torch.utils.data.StackDataset(*args, **kwargs):可以做多种数据集的集成
images = ImageDataset()
texts = TextDataset()
tuple_stack = StackDataset(images, texts)
tuple_stack[0] == (images[0], texts[0])
dict_stack = StackDataset(image=images, text=texts)
dict_stack[0] == {'image': images[0], 'text': texts[0]}
torch.utils.data.ConcatDataset(datasets):数据集拼接
mnist_train = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
cifar10_train = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
concat_dataset = ConcatDataset([mnist_train, cifar10_train])
torch.utils.data.Subset(dataset, indices):根据indices划分子集
dataset (Dataset) – 完整数据集
indices (sequence) – 为子集选择的索引
2. batch_size (int, optional) – how many samples per batch to load (default: 1).
3. shuffle (bool, optional) – set to True to have the data reshuffled at every epoch (default: False).
4. sampler (Sampler or Iterable, optional) – defines the strategy to draw samples from the dataset. Can be any Iterable with len implemented. If specified, shuffle must not be specified.
iterable-style的数据集,他的迭代方式就靠数据集自己的定义啦,相当于Sampler or Iterable中的Iterable。
而Sampler or Iterable中的Sampler有很多种。
torch.utils.data.Sampler Sampler类中最基本的。
sampler = torch.utils.data.Sampler(dataset)
print(list(sampler)) # 输出的是一个索引列表
每个Sampler子类必须提供一个__iter__()方法,提供一种迭代数据集元素的索引或索引列表(批次)的方法,以及一个__len__()方法,返回返回的迭代器的长度。
torch.utils.data.RandomSampler(data_source, replacement=False, num_samples=None, generator=None):随机顺序的,replacement决定是否放回,有放回的情况下,可以通过num_samplers决定抽样个数,generator 参数是一个随机数生成器,可以指定。
# 设置全局随机数生成器的种子
torch.manual_seed(42)
# 创建一个数据集
data_source = [...]
# 创建一个随机抽样器,并使用全局随机数生成器
sampler = RandomSampler(data_source)
torch.utils.data.BatchSampler(sampler, batch_size, drop_last):包装另一个采样器以产生一个小批量的索引。drop_last决定最后不够装进一个batch的元素是否去掉。
sampler = torch.utils.data.BatchSampler(randomsampler)
print(list(sampler)) # 输出的是一个有很多个batch_size大小的索引列表的列表
class InfiniteSampler(torch.utils.data.Sampler):
""" Infinite Sampler for PyTorch.
Inspired from : https://github.com/facebookresearch/DomainBed
Args:
sampler (torch.utils.data.Sampler): Sampler to be used for the infinite sampling.
"""
def __init__(self, sampler):
self.sampler = sampler
def __iter__(self):
while True: # 这个判断条件使不断迭代,结束条件是外部明显的结束,比如woods中,包装成dataloader后,5000个step,每个step会get_next_batch,即输出一个batch,执行下面循环中的一步,5000个step中会循环输出batch_sampler中的batch,且注意,batch_sampler中的sampler又是randomsampler,即每一次遍历都会重新分配(应该是)
for batch in self.sampler: # 循环输出batch_sampler中index的batch
yield batch
def __len__(self):
return len(self.sampler)
5. batch_sampler (Sampler or Iterable, optional) – like sampler, but returns a batch of indices at a time. Mutually exclusive with batch_size, shuffle, sampler, and drop_last.
其实上面的BatchSampler和InfiniteSampler都算此类,在Dataloader中也是作为batch_sampler参数。
6. num_workers (int, optional) – how many subprocesses to use for data loading. 0 means that the data will be loaded in the main process. (default: 0)
加载数据的子进程数。
7. collate_fn (Callable, optional) – merges a list of samples to form a mini-batch of Tensor(s). Used when using batched loading from a map-style dataset.
把一个batch中数据整理成相同格式的。