Dataloader

本文详细介绍了PyTorch中的DataLoader模块,包括其功能如数据加载、预处理、批处理、数据扩充、多线程优化、内存管理和迭代器生成。同时,阐述了关键参数如batch_size、shuffle、Sampler和相关类如TensorDataset、StackDataset和ConcatDataset的用法。
摘要由CSDN通过智能技术生成

torch.utils.data.DataLoader

功能

数据加载:DataLoader 会遍历整个数据集,将数据从原始的数据源(如文件、数据库或内存中的数据)加载到内存中。这一步骤通常涉及到读取图像、文本或其他类型的数据,并将其转换为 PyTorch 可以处理的格式。

数据预处理:在数据被加载到内存之后,DataLoader 通常会对数据进行预处理。这可能包括标准化、归一化、缩放、裁剪、颜色通道调整等操作,以确保数据在送入模型之前满足模型的要求。

批处理:在预处理完成后,DataLoader 会将数据组织成批次(batch)。每个批次包含多个数据样本,这些样本通常会在同一个训练循环中一起处理。批处理的大小是由 batch_size 参数控制的。

数据扩充:在某些情况下,为了增加训练数据的多样性,DataLoader 可能会在加载数据时应用数据扩充技术,如随机旋转、翻转、缩放等。

多线程处理:如果 DataLoader 的 num_workers 参数设置大于 0,那么数据加载过程将在多个线程中并行进行,这样可以加快数据加载的速度。但是,需要注意的是,在 Windows 系统上,由于多线程的实现与 Unix-like 系统不同,可能会遇到死锁或其他运行时错误。

内存管理:DataLoader 还负责管理内存使用情况,确保数据在训练过程中有效地加载和释放,以避免内存不足的问题。

迭代器生成:最终,DataLoader 生成一个迭代器,该迭代器可以在训练循环中迭代,每次迭代提供一个新的批次数据。

参数

1. dataset (Dataset)

Map-style:torch.utils.data.Dataset:子类必须重写__getitem__(),选择性重写__len__(),dataset[idx]可以得到idx-th个元素和对应标签。
Iterable-style:torch.utils.data.IterableDataset:适用流数据,适用于随机读取成本高甚至不可能的情况,以及批处理大小取决于所获取的数据的情况。子类必须重写__iter__(),希望做到不同进程配置不同副本,比如:

def __iter__(self):
	worker_info = torch.utils.data.get_worker_info()
	if worker_info is None:  # 单进程返回完整迭代器
		iter_start = self.start
		iter_end = self.end
	else:  # 多进程变量per_worker通过将数据集的范围(self.end - self.start)除以工作进程的数量(worker_info.num_workers)来计算,结果使用math.ceil()向上取整。
		per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
		worker_id = worker_info.id
		iter_start = self.start + worker_id * per_worker
		iter_end = min(iter_start + per_worker, self.end)
	return iter(range(iter_start, iter_end))

init()也都得自己写吧,才能定义__getitem__()等吧。
还有
torch.utils.data.TensorDataset(*tensors):每个tensor一条数据,照第0个维度索引

input_tensor = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=torch.float32)
output_tensor = torch.tensor([0, 1, 0], dtype=torch.long)
dataset = torch.utils.data.TensorDataset(input_tensor, output_tensor) # 第0个维度要相同

torch.utils.data.StackDataset(*args, **kwargs):可以做多种数据集的集成

images = ImageDataset()
texts = TextDataset()
tuple_stack = StackDataset(images, texts)
tuple_stack[0] == (images[0], texts[0])
dict_stack = StackDataset(image=images, text=texts)
dict_stack[0] == {'image': images[0], 'text': texts[0]}

torch.utils.data.ConcatDataset(datasets):数据集拼接

mnist_train = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
cifar10_train = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
concat_dataset = ConcatDataset([mnist_train, cifar10_train])

torch.utils.data.Subset(dataset, indices):根据indices划分子集

dataset (Dataset) – 完整数据集
indices (sequence) – 为子集选择的索引

2. batch_size (int, optional) – how many samples per batch to load (default: 1).

3. shuffle (bool, optional) – set to True to have the data reshuffled at every epoch (default: False).

4. sampler (Sampler or Iterable, optional) – defines the strategy to draw samples from the dataset. Can be any Iterable with len implemented. If specified, shuffle must not be specified.

iterable-style的数据集,他的迭代方式就靠数据集自己的定义啦,相当于Sampler or Iterable中的Iterable。
而Sampler or Iterable中的Sampler有很多种。
torch.utils.data.Sampler Sampler类中最基本的。

sampler = torch.utils.data.Sampler(dataset)
print(list(sampler)) # 输出的是一个索引列表

每个Sampler子类必须提供一个__iter__()方法,提供一种迭代数据集元素的索引或索引列表(批次)的方法,以及一个__len__()方法,返回返回的迭代器的长度。

torch.utils.data.RandomSampler(data_source, replacement=False, num_samples=None, generator=None):随机顺序的,replacement决定是否放回,有放回的情况下,可以通过num_samplers决定抽样个数,generator 参数是一个随机数生成器,可以指定。

# 设置全局随机数生成器的种子
torch.manual_seed(42)
# 创建一个数据集
data_source = [...]
# 创建一个随机抽样器,并使用全局随机数生成器
sampler = RandomSampler(data_source)

torch.utils.data.BatchSampler(sampler, batch_size, drop_last):包装另一个采样器以产生一个小批量的索引。drop_last决定最后不够装进一个batch的元素是否去掉。

sampler = torch.utils.data.BatchSampler(randomsampler)
print(list(sampler)) # 输出的是一个有很多个batch_size大小的索引列表的列表
class InfiniteSampler(torch.utils.data.Sampler):
    """ Infinite Sampler for PyTorch.
    Inspired from : https://github.com/facebookresearch/DomainBed
    Args:
        sampler (torch.utils.data.Sampler): Sampler to be used for the infinite sampling.
    """
    def __init__(self, sampler):
        self.sampler = sampler
    def __iter__(self):
        while True: # 这个判断条件使不断迭代,结束条件是外部明显的结束,比如woods中,包装成dataloader后,5000个step,每个step会get_next_batch,即输出一个batch,执行下面循环中的一步,5000个step中会循环输出batch_sampler中的batch,且注意,batch_sampler中的sampler又是randomsampler,即每一次遍历都会重新分配(应该是)
            for batch in self.sampler: # 循环输出batch_sampler中index的batch
                yield batch
    def __len__(self):
        return len(self.sampler)

5. batch_sampler (Sampler or Iterable, optional) – like sampler, but returns a batch of indices at a time. Mutually exclusive with batch_size, shuffle, sampler, and drop_last.

其实上面的BatchSampler和InfiniteSampler都算此类,在Dataloader中也是作为batch_sampler参数。

6. num_workers (int, optional) – how many subprocesses to use for data loading. 0 means that the data will be loaded in the main process. (default: 0)

加载数据的子进程数。

7. collate_fn (Callable, optional) – merges a list of samples to form a mini-batch of Tensor(s). Used when using batched loading from a map-style dataset.

把一个batch中数据整理成相同格式的。

8. pin_memory (bool, optional) – If True, the data loader will copy Tensors into device/CUDA pinned memory before returning them. If your data elements are a custom type, or your collate_fn returns a batch that is a custom type, see the example below.

9. drop_last (bool, optional) – set to True to drop the last incomplete batch, if the dataset size is not divisible by the batch size. If False and the size of dataset is not divisible by the batch size, then the last batch will be smaller. (default: False)

10. timeout (numeric, optional) – if positive, the timeout value for collecting a batch from workers. Should always be non-negative. (default: 0)

11. worker_init_fn (Callable, optional) – If not None, this will be called on each worker subprocess with the worker id (an int in [0, num_workers - 1]) as input, after seeding and before data loading. (default: None)

12. multiprocessing_context (str or multiprocessing.context.BaseContext, optional) – If None, the default multiprocessing context of your operating system will be used. (default: None)

13. generator (torch.Generator, optional) – If not None, this RNG will be used by RandomSampler to generate random indexes and multiprocessing to generate base_seed for workers. (default: None)

14. prefetch_factor (int, optional, keyword-only arg) – Number of batches loaded in advance by each worker. 2 means there will be a total of 2 * num_workers batches prefetched across all workers. (default value depends on the set value for num_workers. If value of num_workers=0 default is None. Otherwise, if value of num_workers > 0 default is 2).

15. persistent_workers (bool, optional) – If True, the data loader will not shut down the worker processes after a dataset has been consumed once. This allows to maintain the workers Dataset instances alive. (default: False)

16. pin_memory_device (str, optional) – the device to pin_memory to if pin_memory is True.

  • 19
    点赞
  • 31
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值