DataLoader 的基本用法

DataLoader

在深度学习中,数据加载和预处理是训练模型的关键步骤。PyTorch 提供了 DataLoader 用于简化这一过程。本文将详细介绍 PyTorch 中 DataLoader 的使用,包括基本用法、常见参数及自定义数据集的方式。

基本概念

什么是 DataLoader

DataLoader 是 PyTorch 中的一个类,用于将数据集(通常是一个 Dataset 对象)打包成一个可迭代的对象,方便在训练过程中逐批次读取数据。DataLoader 可以处理数据的随机打乱、并行加载、多线程加载等。

DataLoader 的基本构造

DataLoader 从数据集中取样本,并能够在多个线程中异步读取数据,这是其设计的关键点。


from torch.utils.data import DataLoader

dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)

dataset: 一个继承自 Dataset 的对象,定义了如何获取数据。
batch_size: 指定了每个批次的数据量。
shuffle: 是否在每轮迭代时打乱数据。
num_workers: 加载数据时使用的子进程数,默认值为0,即使用主进程。

使用 DataLoader 加载数据

示例数据集
使用一个简单的数据集来演示 DataLoader 的基本用法。


from torch.utils.data import DataLoader, Dataset

import torch

class SimpleDataset(Dataset):

    def __init__(self):

        self.data = torch.arange(100).float().unsqueeze(1)  # 100个样本每个样本包含一个特征

        self.labels = torch.arange(100).float()  # 标签与数据相同

    def __len__(self):

        return len(self.data)

    def __getitem__(self, idx):

        x = self.data[idx]

        y = self.labels[idx]

        return x, y

dataset = SimpleDataset()

创建 DataLoader


dataloader = DataLoader(dataset, batch_size=10, shuffle=True, num_workers=2)

for batch in dataloader:

    inputs, labels = batch

    print(inputs, labels)

DataLoader 的常见参数

batch_size
批次大小,默认为1。指每次迭代所返回的数据量。


dataloader = DataLoader(dataset, batch_size=32)

shuffle
是否在每个 Epoch 开始时打乱数据。默认为 False。


dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

num_workers
加载数据时使用的子进程数。默认为 0,即在主进程中加载数据。如果设置为一个大于 0 的数值,则会使用多个进程来加载数据。


dataloader = DataLoader(dataset, batch_size=32, num_workers=4)

自定义 Dataset

有时,内置的数据集无法满足我们的需求,此时我们需要自定义数据集。自定义 Dataset 需要继承 torch.utils.data.Dataset 类并重写 lengetitem 方法。


class CustomDataset(Dataset):

    def __init__(self, data, labels):

        self.data = data

        self.labels = labels

    def __len__(self):

        return len(self.data)

    def __getitem__(self, idx):

        x = self.data[idx]

        y = self.labels[idx]

        return x, y

data = torch.randn(100, 3)  # 100个样本,每个样本包含3个特征

labels = torch.randint(0, 2, (100,))  # 100个标签,值为0或1

custom_dataset = CustomDataset(data, labels)

dataloader = DataLoader(custom_dataset, batch_size=10, shuffle=True)

for batch in dataloader:

    inputs, labels = batch

    print(inputs, labels)

数据转换

在实际应用中,通常需要对数据进行一定的预处理,这时可以使用 torchvision.transforms


from torchvision import transforms

class TransformedDataset(Dataset):

    def __init__(self, data, labels, transform=None):

        self.data = data

        self.labels = labels

        self.transform = transform

    def __len__(self):

        return len(self.data)

    def __getitem__(self, idx):

        x = self.data[idx]

        y = self.labels[idx]

        if self.transform:

            x = self.transform(x)

        return x, y

transform = transforms.Compose([

    transforms.Normalize(mean=[0.5], std=[0.5])

])

data = torch.randn(100, 1)

labels = torch.randint(0, 2, (100,))

transformed_dataset = TransformedDataset(data, labels, transform=transform)

dataloader = DataLoader(transformed_dataset, batch_size=10, shuffle=True)

for batch in dataloader:

    inputs, labels = batch

    print(inputs, labels)

总结

本文介绍了 PyTorch 中 DataLoader 的基本用法、常见参数、自定义数据集以及数据转换。通过合理使用 DataLoader,可以有效简化数据加载和预处理过程,从而专注于模型开发和优化

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值