Pytorch中iter(dataloader)的使用

本文介绍了PyTorch中的DataLoader如何作为可迭代对象工作,通过iter()和enumerate()访问数据集。示例展示了如何加载MNIST数据集,并以批次方式处理图像和标签。在使用enumerate()时,注意imgs和labels的顺序,它们分别代表了图像数据和对应的标签值。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

dataloader本质上是一个可迭代对象,可以使用iter()进行访问,采用iter(dataloader)返回的是一个迭代器,然后可以使用next()访问。
也可以使用enumerate(dataloader)的形式访问。
下面举例说明:

transformation = transforms.Compose([
    transforms.ToTensor()
])

train_ds = datasets.MNIST("./data", train=True, transform=transformation, download=True)

test_ds = datasets.MNIST("./data", train=False, transform=transformation, download=True)

train_dl = DataLoader(train_ds, batch_size=64, shuffle=True)
test_dl = DataLoader(test_ds, batch_size=256)
#imgs, labels = next(iter(train_dl))

for labels, imgs in enumerate(train_dl): #如果imgs在前,labels在后,那么imgs将是标签形式,labels才是图片转化0~1之间的值。
    print("imgs:\t", imgs)
    print("labels:\t", labels)
labels:	 3
imgs:	 [tensor([[[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],


        ...,


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]]]), tensor([5, 1, 7, 6, 7, 9, 6, 4, 0, 4, 0, 4, 4, 4, 2, 7, 5, 2, 9, 2, 1, 9, 1, 8,
        2, 6, 8, 0, 1, 6, 1, 0, 3, 6, 6, 2, 5, 1, 3, 4, 4, 1, 8, 4, 8, 1, 2, 5,
        2, 0, 1, 3, 6, 6, 0, 1, 7, 6, 0, 8, 3, 7, 1, 6])]

iter(dataloader)访问时,imgs在前,labels在后,分别表示:图像转换0~1之间的值,labels为标签值。并且imgs和labels是按批次进行输入的。

transformation = transforms.Compose([
    transforms.ToTensor()
])

train_ds = datasets.MNIST("./data", train=True, transform=transformation, download=True)

test_ds = datasets.MNIST("./data", train=False, transform=transformation, download=True)

train_dl = DataLoader(train_ds, batch_size=64, shuffle=True)
test_dl = DataLoader(test_ds, batch_size=256)
#imgs, labels = next(iter(train_dl))
"""
for labels, imgs in enumerate(train_dl):
    print("imgs:\t", imgs)
    print("labels:\t", labels)
"""
for imgs, labels in iter(train_dl):
    print("imgs:\t", imgs)
    print("label:\t", labels)
imgs:	 tensor([[[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],


        ...,


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]],


        [[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]]])
label:	 tensor([6, 1, 0, 8, 6, 7, 8, 1, 3, 4, 8, 5, 8, 9, 7, 2, 9, 3, 0, 6, 1, 1, 4, 6,
        0, 6, 7, 9, 3, 7, 1, 3, 5, 2, 7, 1, 1, 0, 3, 0, 1, 0, 8, 7, 5, 1, 5, 6,
        3, 3, 1, 3, 8, 6, 8, 7, 6, 3, 8, 3, 1, 0, 2, 7])
### PyTorch DataLoader 的功能与使用方法 #### 1. 基础概念 `DataLoader` 是 PyTorch 中用于批量加载数据的核心工具之一。它通过封装 `Dataset` 对象,提供了一种高效的方式来处理大规模数据集并支持多线程读取[^1]。 #### 2. 参数详解 以下是 `DataLoader` 的主要参数及其作用: - **dataset**: 这是一个实现了 `__getitem__()` 和 `__len__()` 方法的对象,表示要加载的数据集合。 - **batch_size**: 定义每次迭代返回的样本数量,默认值为 1。 - **shuffle**: 如果设置为 True,则会在每个 epoch 开始前打乱数据顺序(仅当未指定 sampler 时有效)。默认值为 False。 - **sampler**: 自定义采样器对象,用于控制数据加载的顺序。如果指定了 sampler,则 shuffle 应该设为 None 或者不指定[^2]。 - **num_workers**: 表示用于数据加载的子进程数。增加此数值可以加速数据预处理过程,尤其是在 GPU 训练场景下推荐大于零的值。 - **collate_fn**: 用户自定义函数,用来合并一批次的数据样本到张量或其他结构化形式中去。如果没有特别需求的话会采用默认实现方式。 #### 3. 使用实例 下面展示如何创建一个简单的 `DataLoader` 并结合自定义 `Sampler` 来完成特定任务: ```python from torch.utils.data import Dataset, DataLoader, Sampler class MyCustomDataset(Dataset): def __init__(self, data_list): self.data = data_list def __len__(self): return len(self.data) def __getitem__(self, idx): return self.data[idx] class CustomSampler(Sampler): def __init__(self, data_source): super().__init__(data_source) self.indices = list(range(len(data_source))) def __iter__(self): random.shuffle(self.indices) # 随机排列索引 return iter(self.indices) def __len__(self): return len(self.indices) # 创建数据集和采样器 my_dataset = MyCustomDataset([i for i in range(10)]) custom_sampler_instance = CustomSampler(my_dataset) # 初始化 Data Loader dataloader = DataLoader( my_dataset, batch_size=2, sampler=custom_sampler_instance, num_workers=0 ) for batch_data in dataloader: print(batch_data) ``` 上述代码片段展示了如何构建一个带有随机抽样的 `DataLoader` 实例[^3]。 #### 4. 数据增强 虽然 `DataLoader` 主要是负责数据分发的工作流管理,但它也可以配合其他库或者模块来进行图像变换等操作以达到数据扩增的目的。例如 torchvision.transforms 提供了一系列丰富的转换手段可以帮助我们轻松实现这一点。 ---
评论 10
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值