torch.utils.data.DataLoader使用方法

本文详细介绍了PyTorch中DataLoader的功能与用法,DataLoader是PyTorch提供的数据加载工具,能结合数据集和取样器,实现数据的批处理和多线程读取,适用于模型训练过程中的数据管理,通过示例展示了如何使用DataLoader将数据集分为多个批次,进行高效的数据迭代。

部分转发

https://www.cnblogs.com/demo-deng/p/10623334.html

PyTorch 中的数据类型 torch.utils.data.DataLoader

数据加载器,结合了数据集和取样器,并且可以提供多个线程处理数据集。
在训练模型时使用到此函数,用来把训练数据分成多个小组,此函数每次抛出一组数据。直至把所有的数据都抛出。就是做一个数据的初始化。

"""
    批训练,把数据变成一小批一小批数据进行训练。
    DataLoader就是用来包装所使用的数据,每次抛出一批数据
"""
import torch
import torch.utils.data as Data

BATCH_SIZE = 5

x = torch.linspace(1, 10, 10)
y = torch.linspace(10, 1, 10)
# 把数据放在数据库中
torch_dataset = Data.TensorDataset(x, y)
loader = Data.DataLoader(
    # 从数据库中每次抽出batch size个样本
    dataset=torch_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
)


def show_batch():
    for epoch in range(3):
        for step, (batch_x, batch_y) in enumerate(loader):
            # training


            print("steop:{}, batch_x:{}, batch_y:{}".format(step, batch_x, batch_y))


if __name__ == '__main__':
    show_batch()

根据设置每个epoch 进行shuffle

### 功能概述 `torch.utils.data.DataLoader` 是 PyTorch 中用于数据加载的重要工具类,它能够将数据集包装成一个可迭代的对象,方便地进行数据加载和批处理操作。该类支持多种功能,包括批量读取、数据打乱、多线程并行加载、自动将数据打包成 batch、数据预处理和增强等 [^3]。 ### 常见参数详解 - **dataset**:传入的 `Dataset` 对象(如自定义或 `torchvision.datasets`)。 - **batch_size**:每个 batch 的样本数量。 - **shuffle**:是否打乱数据(通常训练集为 `True`)。 - **num_workers**:并行加载数据的线程数(越大越快,但依机器决定)。 - **drop_last**:是否丢弃最后一个不足 `batch_size` 的 batch。 - **pin_memory**:若为 `True`,会将数据复制到 CUDA 的 page-locked 内存中(加速 GPU 训练)。 - **collate_fn**:自定义打包 batch 的函数(可用于变长序列、图神经网络等)。 - **sampler**:控制数据采样策略,不能与 `shuffle` 同时使用。 - **persistent_workers**:若为 `True`,worker 在 epoch 间保持运行状态(提高效率,PyTorch 1.7+)。 ### 基本使用示例 ```python from torch.utils.data import Dataset, DataLoader class MyDataset(Dataset): def __init__(self): self.data = [i for i in range(100)] def __len__(self): return len(self.data) def __getitem__(self, idx): return self.data[idx] dataset = MyDataset() loader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=2) for batch in loader: print(batch) ``` ### 自定义 `collate_fn` 示例 当需要处理变长序列或其他特殊情况时,可以通过自定义 `collate_fn` 来实现更灵活的数据打包方式。 ```python def collate_fn(batch): # 自定义打包逻辑 return [item * 2 for item in batch] loader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=2, collate_fn=collate_fn) for batch in loader: print(batch) ``` ### 相关问题 1. 如何在 PyTorch 中自定义 Dataset 并与 DataLoader 结合使用? 2. 如何利用 num_workers 参数加速数据加载? 3. 在什么情况下需要自定义 collate_fn 函数? 4. 如何在使用 DataLoader 时进行数据增强? 5. 如何理解 pin_memory 参数的作用及其对性能的影响?
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值