Dataset,Dataloader详解

Dataset,Dataloader详解

Dataset,Dataloader是什么?

  • Dataset:负责可被Pytorch使用的数据集的创建
  • Dataloader:向模型中传递数据

为什么要了解Dataloader

​ 因为你的神经网络表现不佳的主要原因之一可能是由于数据不佳或理解不足。 因此,以更直观的方式理解、预处理数据并将其加载到网络中非常重要。

​ 通常,我们在默认或知名数据集(如 MNIST 或 CIFAR)上训练神经网络,可以轻松地实现预测和分类类型问题的超过 90% 的准确度。 但是那是因为这些数据集组织整齐且易于预处理。 但是处理自己的数据集时,我们常常无法达到这样高的准确率

Dataloader 的使用

  • 载入相关类
from torch.utils.data import Dataloader
  • 设置相关参数
from torch.utils.data import DataLoader

DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    num_workers=0,
    collate_fn=None,
    pin_memory=False,
 )
"""
dataset:是数据集
batch_size:是指一次迭代中使用的训练样本数。通常我们将数据分成训练集和测试集,并且我们可能有不同的批量大小。
shuffle:是传递给 DataLoader 类的另一个参数。该参数采用布尔值(真/假)。如果 shuffle 设置为 True,则所有样本都被打乱并分批加载。否则,它们会被一个接一个地发送,而不会进行任何洗牌。
num_workers:允许多处理来增加同时运行的进程数
collate_fn:合并数据集
pin_memory:锁页内存:将张量固定在内存中
"""

以minist为例子

# Import MNIST
from torchvision.datasets import MNIST

# Download and Save MNIST 
data_train = MNIST('~/mnist_data', train=True, download=True)

# Print Data
print(data_train)
print(data_train[12])

#Dataset MNIST Number of datapoints: 60000 Root location: /Users/viharkurama/mnist_data Split: Train (<PIL.Image.Image image mode=L size=28x28 at 0x11164A100>, 3)

现在让尝试提取元组,其中第一个值对应于图像,第二个值对应于其各自的标签。 下面是代码片段:

import matplotlib.pyplot as plt

random_image = data_train[0][0]
random_image_label = data_train[0][1]

# Print the Image using Matplotlib
plt.imshow(random_image)
print("The label of the image is:", random_image_label)

让我们使用 DataLoader 类来加载数据集,如下所示。

import torch
from torchvision import transforms

data_train = torch.utils.data.DataLoader(
    MNIST(
          '~/mnist_data', train=True, download=True, 
          transform = transforms.Compose([
              transforms.ToTensor()
          ])),
          batch_size=64,
          shuffle=True
          )

for batch_idx, samples in enumerate(data_train):
      print(batch_idx, samples)

这就是我们使用 DataLoader 加载简单数据集的方式。 但是,我们不能总是对每个数据集都依赖已经有的数据集,要是自己的数据集怎么办

定义自己的数据集

我们将创建一个由数字和文本组成的简单自定义数据集

先介绍两个方法

#__getitem__() 方法通过索引返回数据集中选定的样本。

#__len__() 方法返回数据集的总大小。例如,如果您的数据集包含 1,00,000 个样本,则 len 方法应返回 1,00,000。

class Dataset(object):
    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

​ 创建自定义数据集并不复杂,但作为加载数据的典型过程的附加步骤,有必要构建一个接口以获得良好的抽象(至少可以说是一个很好的语法糖)。 现在我们将创建一个包含数字及其平方值的新数据集。 让我们将数据集称为 SquareDataset。 其目的是返回 [a,b] 范围内的值的平方。 下面是相关代码:

import torch
import torchvision
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms

class SquareDataset(Dataset):
     def __init__(self, a=0, b=1):
         super(Dataset, self).__init__()
         assert a <= b
         self.a = a
         self.b = b
        
     def __len__(self):
         return self.b - self.a + 1
        
     def __getitem__(self, index):
        assert self.a <= index <= self.b
        return index, index**2

data_train = SquareDataset(a=1,b=64)
data_train_loader = DataLoader(data_train, batch_size=64, shuffle=True)
print(len(data_train))

​ 在上面的代码块中,我们创建了一个名为 SquareDataset 的 Python 类,它继承了 PyTorch 的 Dataset 类。 接下来,我们调用了一个 init() 构造函数,其中 a 和 b 分别被初始化为 0 和 1。 超类用于从继承的 Dataset 类中访问 len 和 get_item 方法。 接下来我们使用 assert 语句来检查 a 是否小于或等于 b,因为我们想要创建一个数据集,其中值将位于 a 和 b 之间。

​ 然后,我们使用 SquareDataset 类创建了一个数据集,其中数据值的范围为 1 到 64。我们将其加载到名为 data_train 的变量中。 最后,Dataloader 类在 data_train_loader 中存储的数据上创建了一个迭代器,batch_size 初始化为 64,shuffle 设置为 True。

如何使用transform

​ 当你学会怎么定义自己的数据集的时候,你可能会想要更近 一步的操作,对于你自己的数据集进行剪切或者变换

​ 以CIFAR10为例子

  • 将所有图像调整为 32×32
  • 对图像应用中心裁剪变换
  • 将裁剪后的图像转换为张量
  • 标准化图像

导入必要的模块

import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np

接下来,我们将定义一个名为 transforms 的变量,我们在其中按顺序编写所有预处理步骤。我们使用 Compose 类将所有转换操作链接在一起。

transform = transforms.Compose([
    # resize
    transforms.Resize(32),
    # center-crop
    transforms.CenterCrop(32),
    # to-tensor
    transforms.ToTensor(),
    # normalize
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])

"""
resize:此调整大小转换将所有图像转换为定义的大小。在这种情况下,我们要将所有图像的大小调整为 32×32。因此,我们将 32 作为参数传递。
center-crop:接下来,我们使用 CenterCrop 变换裁剪图像。 我们发送的参数也是分辨率/大小,但由于我们已经将图像大小调整为 32x32,因此图像将与此裁剪中心对齐。 这意味着图像将从中心裁剪 32 个单位(垂直和水平)。
to-tensor:我们使用 ToTensor() 方法将图像转换为张量数据类型。
normalize:这将张量中的所有值归一化,使它们位于 0.5 和 1 之间。

"""

在下一步中,在执行我们刚刚定义的转换之后,我们将使用 trainloader 将 CIFAR 数据集加载到训练集中。

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=False)
  • 5
    点赞
  • 28
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
DataLoaderPyTorch 中用于数据加载和批处理的实用工具。它可以帮助您在训练神经网络时高效地处理数据集。下面是 DataLoader 的常见参数的详细解释: 1. dataset: 这是您要加载和处理的数据集对象。它应该是一个可迭代对象,例如一个 PyTorchDataset 对象。 2. batch_size: 这个参数指定了每个批次中的样本数量。默认值是 1,表示每个批次中只包含一个样本。较大的 batch_size 可以提高训练速度,但可能会占用更多的内存。 3. shuffle: 如果将该参数设置为 True,则会在每个 epoch(训练周期)开始时对数据进行洗牌(随机排序),以增加样本之间的独立性。默认值为 False。 4. sampler: 如果不想使用随机洗牌,可以通过指定一个 Sampler 对象来自定义样本的顺序。Sampler 对象可以根据特定的逻辑来对样本进行采样,例如按类别平衡采样。如果指定了 sampler,那么 shuffle 参数将被忽略。 5. batch_sampler: 类似于 sampler 参数,但是它返回一个批次的索引列表。这个参数可以与 batch_size 参数一起使用,用于自定义批处理的方式。 6. num_workers: 这个参数指定了在数据加载过程中使用的子进程数量。默认值为 0,表示在主进程中加载数据。较大的 num_workers 值可以提高数据加载的速度,但可能会占用更多的系统资源。 7. collate_fn: 这个参数用于指定如何将样本列表转换为批次的张量。默认情况下,它会使用 torch.stack() 来堆叠样本张量。您可以根据自己的需求自定义这个函数。 除了以上列出的参数之外,DataLoader 还有其他一些参数,用于控制如何处理数据集的边界情况、并行加载等。您可以查阅 PyTorch 官方文档以获取更详细的信息。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值