pytorch学习_Dataset & Dataloader

同样是跟着Tutorial学的,博客主要是给自己看笔记。其他人首次学习可能还是直接看Tutorials效果更好一点。
Pytorch官方Totorial Datasets & DataLoaders

数据集

Pytorch提供了两个数据基元(不知道这样翻译准不准确,原文是data primitives)分别是torch.utils.data.DataLoadertorch.utils.data.Dataset,这两个基元允许你使用(pytorch)预先加载好的数据和你自己的数据。其中,后者Dataset存储着样本和对应的标签,DataLoaderDataset外封装一个可迭代对象,使我们方便获取样本。

另外Pytorch还提供了一些继承自torch.utils.data.Dataset的预加载好的数据(如FashionMNIST),这些数据本质上就是那个XXX.Dataset的子类,而且有很多方法。这些数据可以用来训练和测试我们的模型。

上面说的可能有点抽象,而且解释得不是很清楚,给个实在点的例子。

从TorchVision中加载Fasion-MNIST数据集:
Fasion-MNIST里有六万个训练样本和一万个训练样本,每个都是28*28灰度图像,共被分成十类。

加载FashionMNIST Dataset需要以下参数:
root:我们所训练/测试数据的路径
train:指定训练集
download=True:如果root中没有我们想要的数据,则从互联网上下载数据集。
transform和target_transform:指定标签和数据的转换。(这里transform可能有点模糊,在下一章中有transform的介绍)

import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt


training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)

test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)

这样之后数据就会被下载到同一个目录下的"data"文件中了。

data文件

数据的可视化

我们也可以将这些数据集可视化:

labels_map = {
    0: "T-Shirt",
    1: "Trouser",
    2: "Pullover",
    3: "Dress",
    4: "Coat",
    5: "Sandal",
    6: "Shirt",
    7: "Sneaker",
    8: "Bag",
    9: "Ankle Boot",
}
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
for i in range(1, cols * rows + 1):
    sample_idx = torch.randint(len(training_data), size=(1,)).item()
    img, label = training_data[sample_idx]
    figure.add_subplot(rows, cols, i)
    plt.title(labels_map[label])
    plt.axis("off")
    plt.imshow(img.squeeze(), cmap="gray")
plt.show()

这串代码使接在上一串代码下面的,简单解释一下它在干嘛。
label_map是用来索引的。
sample_idx是随机选中训练集中所有照片中的其中一个,后面的.item()方法在pytorch官方tutorial中的tensor板块有介绍,但在我的上一份博客中没有讲,简单来说就是把一个1*1大小的tensor格式的数据转化成一个python类型的数据,如float32,float16,int16啥的。在这段代码中,转变成python的数据后就可以用来索引了。
training_data[0]是图片的像素信息,training_data[1]是图片的类别,用它通过label_map进行索引得到label
稍微解释一下再看一下代码就好理解了。

最后得到的图片也是随机的,随便放两个:
在这里插入图片描述
在这里插入图片描述

自定义数据集

本质上就是自定义类,但由于作者非cs科班学生,pyfthon并没有仔细学过类,自学C++最后由于时间原因中道崩殂,这里可能无法讲得太仔细。

自定义的数据集必须包含三个类方法: __init__,__len____getitem__
另外在定义类之前,需要把图像信息放在img_dir文件中,数据情况放在annotations_file

先上代码:

import os
import pandas as pd
from torchvision.io import read_image

class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        image = read_image(img_path)
        label = self.img_labels.iloc[idx, 1]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label

下面一个一个看过去,首先是init方法
每次实例化一个数据集对象,__init__方法都会被调用。它初始化了图像所在文件img_dir,标签文件annotations_file和转换方式(下一节中有详细介绍)

len方法

len方法比较简单,这还看不懂可以直接入土了。

getitem方法:
getitem方法如其名,作用就是通过给的idx,加载和返回图片和标签。第一句img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])通过os库确定索引的图片所在的路径。其中iloc是panda库里的一个索引方法。self.img_labels.iloc[idx,0]返回的是idx图像的文件名。利用os库的join把img_dir和后面得到的图像文件名给链接起来,就得到了该图像的路径。
第二句image = read_image(img_path)获得图像信息。
第三句label = self.img_labels.iloc[idx, 1]得到图像的标签,同样是用的pandas库里的方法。
下面几句就不解释啦。

DataLoaders

Dataset能一次性获取所有图像的特征(就是图像数据,以后都叫特征啦,features)和标签。但在训练时,我们常常需要把样本切分成几个batch分堆送去训练,对每一次epoch都会对数据重新洗牌,来防止过拟合。

from torch.utils.data import DataLoader

train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

现在把数据装进train_dataloader和test_dataloader之后,需要把他们用迭代的方式取出。

train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
img = train_features[0].squeeze()
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label: {label}")

最终打印的结果为:
Feature batch shape: torch.Size([64, 1, 28, 28])
Labels batch shape: torch.Size([64])
Label: 3

因为上面设置中shuffle=True所以我们每次得到的label和图像都是不一样的。

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
### 回答1: DatasetDataLoaderPyTorch 中用于加载和处理数据的两个主要组件。Dataset 用于从数据源中提取和加载数据,DataLoader 则用于将数据转换为适合机器学习模型训练的格式。 ### 回答2: 在PyTorch中,DatasetDataLoader是用于处理和加载数据的两个重要类。 Dataset是一个抽象类,用于表示数据集对象。我们可以自定义Dataset子类来处理我们自己的数据集。通过继承Dataset类,我们需要实现两个主要方法: - __len__()方法:返回数据集的大小(样本数量) - __getitem__(idx)方法:返回索引为idx的样本数据 使用Dataset类的好处是可以统一处理训练集、验证集和测试集等不同的数据集,将数据进行一致的格式化和预处理。 DataLoader是一个实用工具,用于将Dataset对象加载成批量数据。数据加载器可以根据指定的批大小、是否混洗样本和多线程加载等选项来提供高效的数据加载方式。DataLoader是一个可迭代对象,每次迭代返回一个批次的数据。我们可以通过循环遍历DataLoader对象来获取数据。 使用DataLoader可以实现以下功能: - 数据批处理:将数据集划分为批次,并且可以指定每个批次的大小。 - 数据混洗:可以通过设置shuffle选项来随机打乱数据集,以便更好地训练模型。 - 并行加载:可以通过设置num_workers选项来指定使用多少个子进程来加载数据,加速数据加载过程。 综上所述,DatasetDataLoaderPyTorch中用于处理和加载数据的两个重要类。Dataset用于表示数据集对象,我们可以自定义Dataset子类来处理我们自己的数据集。而DataLoader是一个实用工具,用于将Dataset对象加载成批量数据,提供高效的数据加载方式,支持数据批处理、数据混洗和并行加载等功能。 ### 回答3: 在pytorch中,Dataset是一个用来表示数据的抽象类,它封装了数据集的访问方式和数据的获取方法。Dataset类提供了读取、处理和转换数据的功能,可以灵活地处理各种类型的数据集,包括图像、语音、文本等。用户可以继承Dataset类并实现自己的数据集类,根据实际需求定制数据集。 Dataloader是一个用来加载数据的迭代器,它通过Dataset对象来获取数据,并按照指定的batch size进行分批处理。Dataloader可以实现多线程并行加载数据,提高数据读取效率。在训练模型时,通常将Dataset对象传入Dataloader进行数据加载,并通过循环遍历Dataloader来获取每个batch的数据进行训练。 DatasetDataloader通常配合使用,Dataset用于数据的读取和预处理,Dataloader用于并行加载和分批处理数据。使用DatasetDataloader的好处是可以轻松地处理大规模数据集,实现高效的数据加载和预处理。此外,DatasetDataloader还提供了数据打乱、重复采样、数据划分等功能,可以灵活地控制数据的访问和使用。 总之,DatasetDataloaderpytorch中重要的数据处理模块,它们提供了方便的接口和功能,用于加载、处理和管理数据集,为模型训练和评估提供了便利。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值