python里的Dataset和DataLoader

这篇博客详细介绍了PyTorch中的`Dataset`和`DataLoader`类,这两个类在处理图像数据时常用。`Dataset`是一个抽象基类,用于创建数据集并定义`__getitem__`和`__len__`方法。`DataLoader`则负责从数据集中批量加载和处理数据,支持多线程加载、批大小、采样策略等参数。博客还提到了如何自定义采样器和批处理函数以适应非整型索引的数据集。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

这两个类在加载图片时经常遇到,由于参数比较多,原文档又为英文,所以写篇博客记录一下。

class Dataset(Generic[T_co]):
    r"""一个抽象类表示为:class: ' Dataset '。
    
    所有表示从键到数据样本映射的数据集都应该子类化它。
    所有子类都应该覆写:方法: ' __getitem__ ',
    支持获取已给定键的数据样本。
    :方法:`__len__`, 被用来返回数据集大小通过许多
    :类:`~torch.utils.data.Sampler` 实现,
    且是:类:`~torch.utils.data.DataLoader`的默认参数.

    .. 注释::
      :class:`~torch.utils.data.DataLoader` 
      默认构造一个索引采样器来生成整型索引。
      要使它与带有非整型索引/键的映射样式的数据集一起工作,
      必须提供自定义采样器。
    """

    def __getitem__(self, index) -> T_co:
        raise NotImplementedError

    def __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]':
        return ConcatDataset([self, other])
class DataLoader(Generic[T_co]):
	r""" 参数:
		dataset (dataset):要从其中加载数据的dataset。
		
		batch_size (int,可选):每批要加载多少个样本
		( 默认值:' ' 1 '')。
		
		shuffle (bool,可选):设置为“True”后重新洗牌
		(默认值:' ' False ' ')。
		
		sampler (sampler或Iterable,可选):
		定义绘制策略数据集中的示例。
		可以是任何带有__len__的Iterable实现的。
		如果指定了,:attr: ' shuffle '必须不指定。
		
		batch_sampler (Sampler或Iterable,可选):
		像:attr: 'Sampler ',但是每次返回一批索引。
		与:attr:`batch_size`, :attr:`shuffle`,
		 :attr:`sampler`,和 :attr:`drop_last`相互排斥。
		
		num_workers (int,可选):为数据使用多少个子线程来加载数据。
		' ' 0 ' '表示数据将在主进程中加载。( 默认值:' ' 0 ' ')
		
		collate_fn (callable, optional):
		合并一个样本列表以形成一个mini-batch张量(s)。
		当使用成批装载从地图类型数据集。
		
		pin_memory (bool,可选):如果' ' True ' ',
		数据加载器将复制张量在返回它们之前进入CUDA固定内存。
		如果数据元素是一个自定义类型,或者:attr:
		 ' collate_fn '返回一个自定义类型的批处理

		drop_last (bool,可选):设置为“True”,
		删除最后一批不完整的批,
		如果数据集大小不能被批大小整除。
		如果' '假' '数据集的大小不能被批大小整除,
		然后是最后一批将会更小。(默认值:' '假' ')
		
		timeout(数字,可选):如果为正值,
		则从线程收集超时值。应该总是非负的。
		(默认值:' ' 0 ' ')
		
		worker_init_fn (callable, optional):
		如果不是' ' None ' ',将在每个节点上调用
		worker子进程的worker id(一个int in ' ' 
		[0, num_workers - 1] ' ')为输入,
		在播种之后和数据加载之前。(默认值:' '没有' ')
		
		prefetch_factor (int,可选,仅关键字参数):
		每个线程提前装样的次数。' ' 2 ' '
		表示将在所有workers中预取2 * num_workers样本。
		(默认值:' ' 2 ' ')

		persistent_workers (bool,可选):
		如果' ' True ' ',
		数据加载器不会在数据集被消耗一次后关闭工作进程。
		这允许维持“Dataset”实例线程存活。
		(默认值:' '假' ')"""
在 PyTorch 中,Dataloader 是一个非常常用的工具,用于将数据集加载到模型中,以便进行训练或测试。下面是一个简单的使用 DataLoader 的示例: 首先需要导入必要的包: ``` import torch from torch.utils.data import DataLoader, Dataset ``` 接下来,我们需要创建一个自定义的数据集类,继承 `Dataset` 类,并实现其中的 `__len__` `__getitem__` 方法: ``` class MyDataset(Dataset): def __init__(self, data): self.data = data def __len__(self): return len(self.data) def __getitem__(self, idx): return self.data[idx] ``` 在上面的代码中,`__len__` 方法返回数据集的大小,`__getitem__` 方法返回指定索引的数据。 接下来,我们需要将数据集实例化,并创建一个 DataLoader 对象: ``` data = [1, 2, 3, 4, 5] dataset = MyDataset(data) dataloader = DataLoader(dataset, batch_size=2, shuffle=True) ``` 在上面的代码中,`MyDataset` 类的实例 `dataset` 用于存储数据,`DataLoader` 类的实例 `dataloader` 用于将数据集加载到模型中。其中,`batch_size` 参数指定每个 batch 的大小,`shuffle` 参数指定是否随机打乱数据集。 最后,我们可以使用 `dataloader` 对象迭代数据集,以便将其加载到模型中: ``` for batch in dataloader: print(batch) ``` 在上面的代码中,`batch` 变量将依次包含每个 batch 的数据。我们可以在其中添加模型训练或测试的代码,以便进行模型训练或测试。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值