pytorch学习笔记之Dataset与DataLoader

Dataset与DataLoader的目的是将训练与数据分开,从而让使用者更加关于住训练,而非组织数据。首先来看一下Dataset.

Dataset

Dataset的导入

Dataset从torch.utils.data导入

from torch.utils.data import Dataset

Dataset的使用

目前,有许多标准数据集已经使用Dataset进行了封装,所以,在使用这类数据集时,直接使用即可:

import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt


training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)

test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)
  • root表示数据存储位置
  • train表示是否作为训练数据集
  • dowload表示在root所在的位置找不到数据时,是否从网上下载
  • transform表示将data做数据转换。

自定义Dataset

对于Dataset来说, 有三个必不可少的函数,分别为__init__, __len__, 和__getitem__, 他们的功能分别为:

  • __init__ 对Dataset进行初始化;
  • __len__ 返回Dataset的长度;
  • __getitem__返回对应的数据和标签
    以下给出一个例子:
import os
import pandas as pd
from torchvision.io import read_image

class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        image = read_image(img_path)
        label = self.img_labels.iloc[idx, 1]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label

DataLoader

DataLoader 是 PyTorch 中用于加载数据的一个实用工具,能够帮助有效地加载和组织数据,用于训练和评估深度学习模型。DataLoader 本质上是一个迭代器,可以从自定义的 Dataset 类中加载数据,并提供多线程数据加载、批量处理、数据随机洗牌等功能,从而简化了数据的准备过程。

DataLoader的引入

from torch.utils.data import Dataset, DataLoader

DataLoader的使用

下面是 DataLoader 的基本使用方式以及其常用参数的意义和用法:

from torch.utils.data import DataLoader

# 创建自定义的 Dataset 实例
dataset = ...  # 自定义的 Dataset 对象

# 设置批次大小、是否随机洗牌、以及并行加载的线程数
batch_size = 32
shuffle = True
num_workers = 4

# 创建 DataLoader 实例
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
for imgs, labels in dataloader:
	...
	

参数解释:

  1. dataset:要加载数据的自定义 Dataset 对象。

  2. batch_size:每个批次的样本数量。在训练中,模型会同时处理这么多个样本,从而加快训练速度。

  3. shuffle:是否在每个 epoch(训练轮次)之前随机洗牌数据,以避免模型在连续的批次中看到相同的数据。

  4. num_workers:并行加载的线程数,用于异步加载数据。设置较大的值可以加速数据加载,但注意不要超过系统的核心数。

除了上述常用参数外,DataLoader 还有其他一些参数,例如:

  • collate_fn:用于对批次中的样本进行自定义的合并操作,可以用于处理不同大小的样本数据。

  • drop_last:如果数据样本数量不能被批次大小整除,设置为 True 会丢弃最后一个不完整的批次。

  • pin_memory:如果设为 True,会将数据加载到 CUDA 的固定内存中,加速数据传输到 GPU。

  • timeout:用于设置加载数据的超时时间。

  • worker_init_fn:用于每个 worker 线程的初始化函数。

  • sampler 参数用于指定数据采样的策略。它决定了数据在每个 epoch 中的遍历顺序。你可以使用不同的采样策略来影响模型的训练效果。

  • batch_sampler 参数允许你自定义批次采样的策略。与简单地设置 batch_size 不同,batch_sampler 可以控制每个批次中的样本数量和样本顺序。

使用 DataLoader 可以大大简化数据加载的过程,使数据准备的流程更加高效和灵活。根据项目的需求,可以调整不同的参数来优化数据加载和训练过程。

samplerbatch_sampler参数

对于某些任务,samplerbatch_sampler可能需要自己重写,定义。Pytorch提供了五种sampler的基类,可以对基类进行重写,自定义。

  1. SequentialSampler:按照数据集中样本的顺序进行采样。
from torch.utils.data import DataLoader, SequentialSampler

# 使用 SequentialSampler 进行顺序采样
dataset = MyDataset(...)
sampler = SequentialSampler(dataset)  # 顺序采样
data_loader = DataLoader(dataset, batch_size=32, sampler=sampler, num_workers=4)
  1. RandomSampler:随机地从数据集中采样样本。
from torch.utils.data import DataLoader, RandomSampler

# 使用 RandomSampler 进行随机采样
dataset = MyDataset(...)
sampler = RandomSampler(dataset)  # 随机采样
data_loader = DataLoader(dataset, batch_size=32, sampler=sampler, num_workers=4)
  1. SubsetRandomSampler:从指定的子集中随机采样样本
from torch.utils.data import DataLoader, SubsetRandomSampler

# 从子集中进行随机采样
indices = [0, 3, 5, 7, 9]  # 选定的样本索引
sampler = SubsetRandomSampler(indices)  # 子集随机采样
data_loader = DataLoader(dataset, batch_size=32, sampler=sampler, num_workers=4)
  1. WeightedRandomSampler:根据指定的样本权重进行随机采样,适用于不均衡数据集。
from torch.utils.data import DataLoader, WeightedRandomSampler

# 使用 WeightedRandomSampler 进行加权随机采样
dataset = MyDataset(...)
weights = [0.2, 0.3, 0.1, 0.4]  # 样本权重
sampler = WeightedRandomSampler(weights, num_samples=len(dataset), replacement=True)  # 加权随机采样
data_loader = DataLoader(dataset, batch_size=32, sampler=sampler, num_workers=4)
  1. BatchSampler:用于自定义批次采样的策略,可以控制每个批次中的样本数量和样本顺序。
from torch.utils.data import DataLoader, BatchSampler

# 使用 BatchSampler 进行自定义批次采样
dataset = MyDataset(...)
batch_size = 32
batch_sampler = BatchSampler(dataset, batch_size=batch_size, drop_last=True)  # 自定义批次采样
data_loader = DataLoader(dataset, batch_sampler=batch_sampler, num_workers=4)

以上示例演示了不同采样器的用法,可以根据你的数据集和需求选择适合的采样器。注意,samplershuffle参数不一起使用,即指定sampler后,不再设置shuffle.

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值