PyTorch提供两个数据类:torch.utils.data.DataLoader
和torch.utils.data.Daaset
,可以让你使用预先加载的数据集和你自己的数据集。Dataset储存样本和对应的labels,DataLoader在Dataset上包装了一个iterable,用来获取样本
PyTorch库提供了许多预加载的数据集,比如FashionMNIST,这些数据集是torch.utils.data.Dataset
的子集,针对特定的数据集数据做了函数实现。它们可以被用来prototype和benchmark你的模型
载入PyTorch中存在的数据集
下面是一个从TorchVision中加载FashionMNIST数据集的例子,FashionMNIST是Zalando文章中的数据集,包括60,000张训练样本和10,000个测试样本。每个样本由一个28x28的灰度图和一个10 classes中的标签组成
加载FashionMNIST Dataset使用的参数如下:
root
训练和测试数据存储的地方train
指定是训练集还是测试集download
如果root中没有数据,是否需要从网络下载transform
和target_transform
指定图片和label的transformation
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
training_data = datasets.FashionMNIST(
root="data",
train=True,
download=True,
transform=ToTensor()
)
test_data = datasets.FashionMNIST(
root="data",
train=False,
download=True,
transform=ToTensor()
)
加载后的training_data和test_data的类型是<class 'torchvision.datasets.mnist.FashionMNIST'>
,但是可以进行迭代,每个元素是一个tuple,包含图片的tensor和label的数字(0-9之间)。每个tensor的shape为[1, 28, 28]
,即 c, h, w
在下面的custom Dataset class,我们实现的__getitem__
函数的返回语句为return image, label
,它所返回的就是该Dataset类实例化后的dataset对象的一个元素,即img tensor(或者别的训练用例)和label
为自己的数据创建Custom Dataset
一个custom Dataset类必须要实现3个函数:__init__
,__len__
和__getitem__
。在下面的实现中,FashionMNIST 图片存储在路径img_dir中,labels储存在有个CSV文件annotations_file中
import os
import pandas as pd
from torchvision.io import read_image
class CustomImageDataset(Dataset):
def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
self.img_labels = pd.read_csv(annotations_file)
self.img_dir = img_dir
self.transform = transform
self.target_transform = target_transform
def __len__(self):
return len(self.img_labels)
def __getitem__(self, idx):
img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
image = read_image(img_path)
label = self.img_labels.iloc[idx, 1]
if self.transform:
image = self.transform(image)
if self.target_transform:
label = self.target_transform(label)
return image, label
__init__
__init__
函数在创建Dataset对象时运行一次。我们初始化包含图片的路径、标注文件路径、两个transforms
CSV文件labels.csv:
tshirt1.jpg, 0
tshirt2.jpg, 0
......
ankleboot999.jpg, 9
def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
self.img_labels = pd.read_csv(annotations_file)
self.img_dir = img_dir
self.transform = transform
self.target_transform = target_transform
__len__
__len__
函数返回数据集中的样本的个数
例如:
def __len__(self):
return len(self.img_labels)
__getitem__
__getitem__
函数根据给定的索引idx从数据集中加载并返回一个样本。根据这个索引,函数将找到图片在硬盘中的位置,用read_image
将之转化成tensor,并且从csv数据中取得对应的标签,再调用transform函数,最后用tuple的形式返回tensor image和对应的label
def __getitem__(self, idx):
img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
image = read_image(img_path)
label = self.img_labels.iloc[idx, 1]
if self.transform:
image = self.transform(image)
if self.target_transform:
label = self.target_transform(label)
return image, label
使用DataLoader准备训练数据
Dataset类创建的dataset对象是一个诸如<class ‘torchvision.datasets.mnist.FashionMNIST’>的对象,前面已经介绍过,它可以迭代,每个元素是样本和标签的tuple,这样的数据不能直接用于训练,PyTorch还提供了一个 DataLoader 类,将这些样本组合成一个一个的‘minibatch’,用迭代的方式遍历整个epoch,并且在一个epoch结束后shuffle数据,它还能使用Python的multiprocessing来加速数据取回
Data loader. Combines a dataset and a sampler, and provides an iterable over the given dataset.
from torch.utils.data import DataLoader
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)