官方文档解释很好:
Code for processing data samples can get messy and hard to maintain; we ideally want our dataset code to be decoupled from our model training code for better readability and modularity. PyTorch provides two data primitives: torch.utils.data.DataLoader and torch.utils.data.Dataset that allow you to use pre-loaded datasets as well as your own data. Dataset stores the samples and their corresponding labels, and DataLoader wraps an iterable around the Dataset to enable easy access to the samples.
一目了然了 Dataset 和 DataLoader 是什么,为什么?
- Why:更好的可读性和模块化,和模型训练的核心代码解耦(风格清爽,便于维护)
- What:Dataset存储了样本和标签,DataLoader封装了迭代器(iterable),以便访问Dataset存储的数据
Loading a Dataset
PyTorch 提供了很多 pre-loaded datasets,继承自 torch.utils.data.Dataset,并针对不同的数据实现了特殊的函数。如果想实验一些想法,可直接利用这些数据集作为 prototype,或者用来对比 benchmark 衡量模型的效果。
注意:dataset 也可以直接访问元素,但却不能实现 shuffle、多进程加载 等高级操作。
以 TorchVision 的 Fashion-MNIST 为例。
加载数据:
from torchvision import datasets
from torchvision.transforms import ToTensor
# 以下定义的是 dataset,还不是 dataloader
train_data = datasets.FashionMNIST(root="data", train=True, download=True, transform=ToTensor())
test_data = datasets.FashionMNIST(root="data", train=False, download=True, transform=ToTensor())
可视化数据:
import matplotlib.plot as plt
figure = plt.figure(figsize=(8, 8))
rows = cols = 3
for i in range(1, rows*cols+1):
sample_idx = torch.randint(len(train_data), size(1,)).item() # 只含有一个元素的tensor可以用.item()取值
img, label = train_data[sample_idx]
figure.add_subplot(rows, cols, i)
plt.title(label) # label一般是id
plt.axis("off")
plt.imshow(img.squeeze(), cmap="gray")
plt.show()
Custom Dataset
自定义 dataset 必须实现三个内置函数:init,len,getitem。
import os
import pandas as pd
from torchvision.io import read_image
class CustomImageDataset(Dataset):
def __init__(self, meta_file, img_dir, transform=None, target_trans=None):
self.img_labels = pd.read_csv(meta_file) # 读入标注文件
self.img_dir = img_dir
self.transform = transform
self.target_trans = target_trans
def __len__(self): # number of samples in the dataset
return len(self.img_labels)
def __getitem(self, idx): # load and return a sample from the datset at idx
img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
image = read_image(img_path)
label = self.img_labels.iloc[idx, 1]
if self.transform:
image = self.transform(image) # afterwards, the image becomes a tensor
if self.target_trans:
label = self.target_trans(label)
return image, label
Custom DataLoader
如上实现可见,Dataset是一个存储类型的数据结构,但它的访问接口并不适合训练过程。比如:
- 可一次拿一个样本,但训练时我们通常一次拿多个样本(i.e. minibatch)
- 每个 epoch 需要 reshuffle 数据避免过拟合(数据规律重复)
- 没有多进程加载数据的机制
这些弊端在 DataLoader 类中都可以抽象完成,迎刃而解。
from torch.utils.data import DataLoader
# 将 Dataset 对象加载到 DataLoader,所有 batch 都采样过一遍后对整个数据集 shuffle 一次
train_dataloader = DataLoader(train_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)
# iteration
train_features, train_labels = next(iter(train_dataloader))
img = train_features[0].squeeze()
label = train_labels[0]
plt.title(label)
plt.imshow(img, cmap="gray")
plt.show()
关于 PyTorch DataLoader 更细节的文档,可移步:torch.utils.data