PyTorch Learn the basics系列:2. Datasets & DataLoaders

PyTorch提供两个数据类:torch.utils.data.DataLoadertorch.utils.data.Daaset,可以让你使用预先加载的数据集和你自己的数据集。Dataset储存样本和对应的labels,DataLoader在Dataset上包装了一个iterable,用来获取样本
PyTorch库提供了许多预加载的数据集,比如FashionMNIST,这些数据集是torch.utils.data.Dataset的子集,针对特定的数据集数据做了函数实现。它们可以被用来prototype和benchmark你的模型

载入PyTorch中存在的数据集

下面是一个从TorchVision中加载FashionMNIST数据集的例子,FashionMNIST是Zalando文章中的数据集,包括60,000张训练样本和10,000个测试样本。每个样本由一个28x28的灰度图和一个10 classes中的标签组成

加载FashionMNIST Dataset使用的参数如下:

  • root 训练和测试数据存储的地方
  • train 指定是训练集还是测试集
  • download 如果root中没有数据,是否需要从网络下载
  • transformtarget_transform 指定图片和label的transformation
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt


training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)

test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)

加载后的training_data和test_data的类型是<class 'torchvision.datasets.mnist.FashionMNIST'>,但是可以进行迭代,每个元素是一个tuple,包含图片的tensor和label的数字(0-9之间)。每个tensor的shape为[1, 28, 28],即 c, h, w

在下面的custom Dataset class,我们实现的__getitem__函数的返回语句为return image, label,它所返回的就是该Dataset类实例化后的dataset对象的一个元素,即img tensor(或者别的训练用例)和label

为自己的数据创建Custom Dataset

一个custom Dataset类必须要实现3个函数:__init____len____getitem__。在下面的实现中,FashionMNIST 图片存储在路径img_dir中,labels储存在有个CSV文件annotations_file中

import os
import pandas as pd
from torchvision.io import read_image

class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        image = read_image(img_path)
        label = self.img_labels.iloc[idx, 1]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label

__init__

__init__函数在创建Dataset对象时运行一次。我们初始化包含图片的路径、标注文件路径、两个transforms
CSV文件labels.csv:

tshirt1.jpg, 0
tshirt2.jpg, 0
......
ankleboot999.jpg, 9
def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
    self.img_labels = pd.read_csv(annotations_file)
    self.img_dir = img_dir
    self.transform = transform
    self.target_transform = target_transform

__len__

__len__函数返回数据集中的样本的个数

例如:

def __len__(self):
	return len(self.img_labels)

__getitem__

__getitem__ 函数根据给定的索引idx从数据集中加载并返回一个样本。根据这个索引,函数将找到图片在硬盘中的位置,用read_image将之转化成tensor,并且从csv数据中取得对应的标签,再调用transform函数,最后用tuple的形式返回tensor image和对应的label

def __getitem__(self, idx):
    img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
    image = read_image(img_path)
    label = self.img_labels.iloc[idx, 1]
    if self.transform:
        image = self.transform(image)
    if self.target_transform:
        label = self.target_transform(label)
    return image, label

使用DataLoader准备训练数据

Dataset类创建的dataset对象是一个诸如<class ‘torchvision.datasets.mnist.FashionMNIST’>的对象,前面已经介绍过,它可以迭代,每个元素是样本和标签的tuple,这样的数据不能直接用于训练,PyTorch还提供了一个 DataLoader 类,将这些样本组合成一个一个的‘minibatch’,用迭代的方式遍历整个epoch,并且在一个epoch结束后shuffle数据,它还能使用Python的multiprocessing来加速数据取回

Data loader. Combines a dataset and a sampler, and provides an iterable over the given dataset.

from torch.utils.data import DataLoader

train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

扩展阅读

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值