Pytorch中的dataset和dataloader解析

当我们用 PyTorch 来训练神经网络时,经常需要用到 DatasetDataLoader 这两个类。它们都是 PyTorch 中的数据处理工具,用于读取和处理大量的数据,并将其转换为可供神经网络使用的格式。

Dataset

Dataset 类是一个抽象类,定义了读取数据集的接口方法。我们可以通过继承 Dataset 类,并实现其中的 __len__()__getitem__() 方法来创建自定义的数据集。其中:

  • __len__() 方法返回数据集中样本的数量;
  • __getitem__() 方法可以根据给定的索引(从 0 开始)获取数据集中对应的一个样本。

下面是一个简单的示例,该示例演示了如何从 CSV 文件中读取数据,并将其转换为 PyTorch 的 Tensor 对象:

import torch
from torch.utils.data import Dataset

class MyDataset(Dataset):
    def __init__(self, csv_file):
        self.data = pd.read_csv(csv_file)
        self.X = self.data.iloc[:, :-1].values
        self.y = self.data.iloc[:, -1:].values

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        X = torch.tensor(self.X[index], dtype=torch.float32)
        y = torch.tensor(self.y[index], dtype=torch.float32)
        return X, y

在上述代码中,我们首先通过 Pandas 库将 CSV 文件读入到内存中;然后在 __init__() 方法中将数据集的 X 和 y 值分别存储到 self.Xself.y 中;最后在 __getitem__() 方法中,根据给定的索引获取对应的样本,并将其转换为 PyTorch 的 Tensor 对象。

DataLoader

DataLoader 类是一个可迭代的对象,可以用于批量加载数据集。它可以从 Dataset 对象(或其他类似于数组的对象)中获取数据,并在训练过程中生成随机批次的数据。其核心思想是使用多个 worker 进程并行地预取和处理数据,以加快数据读取速度,并在内存中组织成批次形式。

下面是一个简单的示例:

from torch.utils.data import DataLoader

# 创建 MyDataset 对象
my_dataset = MyDataset('data.csv')

# 创建 DataLoader 对象
batch_size = 32
dataloader = DataLoader(my_dataset, batch_size=batch_size, shuffle=True)

# 遍历 DataLoader 对象
for i, (X_batch, y_batch) in enumerate(dataloader):
    # 在此处进行模型训练或测试
    pass

在上述代码中,我们首先创建了一个 MyDataset 对象,并将其传递给了 DataLoader 构造函数中。我们还指定了 batch_size 参数,表示每个批次中包含的样本数量。设置 shuffle=True 表示每个 epoch 开始时,都会对数据集进行重新洗牌,以增加训练的随机性。

在实际训练过程中,我们可以通过遍历 DataLoader 对象来获取随机批次的数据,并将其用于模型训练或测试。在遍历 DataLoader 对象时,每个批次的数据都以元组 (X_batch, y_batch) 的形式返回,其中 X_batchy_batch 分别表示一个批次中的特征和标签。

补充

下面介绍一下 DataLoader 是如何从 Dataset 对象中获取数据,并生成随机批次的:

  1. 在创建 DataLoader 对象时,需要传入一个 Dataset 对象,表示要加载的数据集。
  2. DataLoader 会依次遍历 Dataset 中的每个样本,并将其添加到一个存储样本的列表中。此时,如果设置了 num_workers 参数,则会启动对应数量的 worker 进程,每个进程都会异步地预取和处理数据,并将结果添加到列表中。
  3. 当列表中的元素数量等于 batch_size 参数指定的值时,DataLoader 就会将这些元素封装成一个 batch,并返回给用户。同时,如果设置了 shuffle 参数为 True,还会随机打乱列表中的元素顺序,以增加数据的随机性。
  4. 如果 Dataset 中的样本数量不足以填满一个 batch,那么 DataLoader 会将剩余样本封装成一个小于 batch_size 的 batch 并返回。另外,如果 drop_last 参数设置为 True,则会舍弃最后一个小于 batch_size 的 batch,以保证每个 batch 中样本数量一致。

通过这种方式,DataLoader 可以高效、方便地从数据集中加载数据,并将其封装成批次形式,供模型训练使用。

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值