pytorch中使用Dataset和DataLoader创建自定义数据集 入门

介绍

pytorch中,我们可以使用torch.utils.data.DataLoadertorch.utils.data.Dataset加载数据集,具体来说,可以简单理解为Dataset是数据集,他提供数据与索引之间的映射,同时也要有标签。而DataLoader是将Dataset中的数据迭代提取出来,从而能够提供给模型。
所以,具体流程是,我们应该先按照要求先建立一个Dataset,之后再建立一个DataLoader,然后就可以用了。
pytorch中有很多现成的数据集,我们下载就可以使用。但是更多时候我们要建立自己的数据集,我也是入门,所以先建立一个带标签的图像数据集。

参考

建立Dataset

我们可以继承torch.utils.data.Dataset类,必须要重写__init__, __len__, 和 __getitem__这三个函数。其中 __len__能够返回我们数据集中的数据个数,__getitem__能够根据索引返回数据。

前提

我们有一个文件夹,里面有很多猫、狗和汽车的照片,此外有一个csv文件,里面是每张照片对应的类别,也就是标签。我们根据这个照片文件夹和csv文件,来建立我们的带标签数据集。

  • 对于图片文件夹:0——29张图片为猫,30——59张图片为狗,其他为汽车。
    图片文件夹
  • 对于标签csv文件,每一行中首先是图片名,然后是类别。其中0代表猫,1代表狗,2代表汽车。如下图:
    标签文件
具体代码
import os
from torchvision.io import read_image
import pandas as pd
from torch.utils.data import Dataset
import matplotlib.pyplot as plt
import numpy as np

class myImageDataset(Dataset):
    def __init__(self, img_dir, img_label_dir, transform=None):
        super().__init__()
        self.img_dir = img_dir
        self.img_labels = pd.read_csv(img_label_dir)  # 这是一个dataframe,0是文件名,1是类别
        self.transform = transform

    def __len__(self):
        return len(self.img_labels)  # 数据集长度
    
    def __getitem__(self, index):
        # 拼接得到图片文件路径
        # 例如img_dir为'D:/curriculum/2022learning/learnning_dataset/data/'
        # img_labels.iloc[index, 0]为5.jpg
        # 那么img_path为'D:/curriculum/2022learning/learnning_dataset/data/5.jpg'
        img_path = os.path.join(self.img_dir + self.img_labels.iloc[index, 0])
        image = read_image(img_path)  # tensor类型
        label = self.img_labels.iloc[index, 1]
        if self.transform is not None:
            image = self.transform(image)  # 对图片进行某些变换
        
        return image, label

代码中都有注释。

__init__()

类的初始化函数,其中img_dir为图片文件夹的根目录,img_label_dir为标签文件路径,transform为对数据项进行的变换。

__len__()

返回数据集长度。

__getitem__()

根据index,返回其在数据集中对应的数据和标签。

验证

通过如下代码,我们具体输出一张图片:

# 把图片对应的tensor调整维度,并显示
def tensorToimg(img_tensor):
    img = img_tensor.numpy()
    img = np.transpose(img_tensor, [1, 2, 0])
    plt.imshow(img)


label_dic = {0: 'cat', 1: 'dog', 2: 'car'}

label_path = 'D:/curriculum/2022learning/learnning_dataset/labels.csv'
img_root_path = 'D:/curriculum/2022learning/learnning_dataset/data/'
dataset = myImageDataset(img_root_path, label_path)

image, label = dataset.__getitem__(33)
print(image.shape)
print(label_dic[label])
tensorToimg(image)

结果
可以看到,数据集中,图片变为tensor,维度为[通道数,长,宽]。

DataLoader

之后就可以使用DataLoader对刚刚创建的数据集不断取出样本了。不再赘述。

dataloader = DataLoader(dataset, batch_size=5, shuffle=True)

这样,我们就建立了一个dataLoader。接下来我们输出一下看看:

for imgs, labels in dataloader:
    print(imgs.shape)
    print(labels)
    break

但是这里报错:stack expects each tensor to be equal size, but got [3, 268, 320] at entry 0 and [3, 480, 370] at ...,查询得知是数据集中图片大小不一,而这时Dataset中定义的参数transfom就派上了用场。我们让每张图片的大小都是224*224

from torch.utils.data import DataLoader
from torchvision import transforms

transform = transforms.Resize((224, 224))

dataset = myImageDataset(img_root_path, label_path, transform)
dataloader = DataLoader(dataset, batch_size=5, shuffle=True)
for imgs, labels in dataloader:
    print(imgs.shape)
    print(labels)
    break

结果为:

torch.Size([5, 3, 224, 224])
tensor([0, 2, 2, 2, 1])

由于batch_size是5,而每个图片的形状为[3, 224, 224],因此一个batch的数据形状为:[5, 3, 224, 224]

其他使用DataLoader的方法
for index, (imgs, labels) in enumerate(dataloader):
    print(index)
    print(imgs.shape)
    print(labels)
    break

结果为:

0
torch.Size([5, 3, 224, 224])
tensor([1, 2, 0, 1, 1])
imgs, label = next(iter(dataloader))
print(imgs.shape)
print(labels)

结果为:

torch.Size([5, 3, 224, 224])
tensor([1, 2, 0, 1, 1])

得到了一批的图片和对应的标签,我们就能将其输入到模型中,并使用标签和预测结果计算损失。

### 回答1: Dataset DataLoaderPyTorch 中用于加载处理数据的两个主要组件。Dataset 用于从数据源中提取加载数据,DataLoader 则用于将数据转换为适合机器学习模型训练的格式。 ### 回答2: 在PyTorch中,DatasetDataLoader是用于处理加载数据的两个重要类。 Dataset是一个抽象类,用于表示数据集对象。我们可以自定义Dataset子类来处理我们自己的数据集。通过继承Dataset类,我们需要实现两个主要方法: - __len__()方法:返回数据集的大小(样本数量) - __getitem__(idx)方法:返回索引为idx的样本数据 使用Dataset类的好处是可以统一处理训练集、验证集测试集等不同的数据集,将数据进行一致的格式化处理DataLoader是一个实用工具,用于将Dataset对象加载成批量数据。数据加载器可以根据指定的批大小、是否混洗样本多线程加载等选项来提供高效的数据加载方式。DataLoader是一个可迭代对象,每次迭代返回一个批次的数据。我们可以通过循环遍历DataLoader对象来获取数据。 使用DataLoader可以实现以下功能: - 数据批处理:将数据集划分为批次,并且可以指定每个批次的大小。 - 数据混洗:可以通过设置shuffle选项来随机打乱数据集,以便更好地训练模型。 - 并行加载:可以通过设置num_workers选项来指定使用多少个子进程来加载数据,加速数据加载过程。 综上所述,DatasetDataLoaderPyTorch中用于处理加载数据的两个重要类。Dataset用于表示数据集对象,我们可以自定义Dataset子类来处理我们自己的数据集。而DataLoader是一个实用工具,用于将Dataset对象加载成批量数据,提供高效的数据加载方式,支持数据批处理、数据混洗并行加载等功能。 ### 回答3: 在pytorch中,Dataset是一个用来表示数据的抽象类,它封装了数据集的访问方式数据的获取方法。Dataset类提供了读取、处理转换数据的功能,可以灵活地处理各种类型的数据集,包括图像、语音、文本等。用户可以继承Dataset类并实现自己的数据集类,根据实际需求定制数据集Dataloader是一个用来加载数据的迭代器,它通过Dataset对象来获取数据,并按照指定的batch size进行分批处理Dataloader可以实现多线程并行加载数据,提高数据读取效率。在训练模型时,通常将Dataset对象传入Dataloader进行数据加载,并通过循环遍历Dataloader来获取每个batch的数据进行训练。 DatasetDataloader通常配合使用Dataset用于数据的读取处理Dataloader用于并行加载分批处理数据。使用DatasetDataloader的好处是可以轻松地处理大规模数据集,实现高效的数据加载处理。此外,DatasetDataloader还提供了数据打乱、重复采样、数据划分等功能,可以灵活地控制数据的访问使用。 总之,DatasetDataloaderpytorch中重要的数据处理模块,它们提供了方便的接口功能,用于加载、处理管理数据集,为模型训练评估提供了便利。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值