【PytorchBasics】Dataset & Dataloader

官方文档解释很好:

Code for processing data samples can get messy and hard to maintain; we ideally want our dataset code to be decoupled from our model training code for better readability and modularity. PyTorch provides two data primitives: torch.utils.data.DataLoader and torch.utils.data.Dataset that allow you to use pre-loaded datasets as well as your own data. Dataset stores the samples and their corresponding labels, and DataLoader wraps an iterable around the Dataset to enable easy access to the samples.

一目了然了 Dataset 和 DataLoader 是什么,为什么?

  • Why:更好的可读性和模块化,和模型训练的核心代码解耦(风格清爽,便于维护)
  • What:Dataset存储了样本和标签,DataLoader封装了迭代器(iterable),以便访问Dataset存储的数据

Loading a Dataset

PyTorch 提供了很多 pre-loaded datasets,继承自 torch.utils.data.Dataset,并针对不同的数据实现了特殊的函数。如果想实验一些想法,可直接利用这些数据集作为 prototype,或者用来对比 benchmark 衡量模型的效果。

注意:dataset 也可以直接访问元素,但却不能实现 shuffle、多进程加载 等高级操作。

以 TorchVision 的 Fashion-MNIST 为例。

加载数据:

from torchvision import datasets
from torchvision.transforms import ToTensor
# 以下定义的是 dataset,还不是 dataloader
train_data = datasets.FashionMNIST(root="data", train=True, download=True, transform=ToTensor())
test_data = datasets.FashionMNIST(root="data", train=False, download=True, transform=ToTensor())

可视化数据:

import matplotlib.plot as plt

figure = plt.figure(figsize=(8, 8))
rows = cols = 3
for i in range(1, rows*cols+1):
	sample_idx = torch.randint(len(train_data), size(1,)).item()  # 只含有一个元素的tensor可以用.item()取值
	img, label = train_data[sample_idx]  
	figure.add_subplot(rows, cols, i)
	plt.title(label)  # label一般是id
	plt.axis("off")
	plt.imshow(img.squeeze(), cmap="gray")
plt.show()

Custom Dataset

自定义 dataset 必须实现三个内置函数:initlengetitem

import os
import pandas as pd
from torchvision.io import read_image

class CustomImageDataset(Dataset):
	def __init__(self, meta_file, img_dir, transform=None, target_trans=None):
		self.img_labels = pd.read_csv(meta_file)  # 读入标注文件
		self.img_dir = img_dir
		self.transform = transform
		self.target_trans = target_trans
	
	def __len__(self):  # number of samples in the dataset
		return len(self.img_labels)
	
	def __getitem(self, idx):  # load and return a sample from the datset at idx
		img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
		image = read_image(img_path)
		label = self.img_labels.iloc[idx, 1]
		if self.transform:
			image = self.transform(image)  # afterwards, the image becomes a tensor
		if self.target_trans:
			label = self.target_trans(label)
		return image, label

Custom DataLoader

如上实现可见,Dataset是一个存储类型的数据结构,但它的访问接口并不适合训练过程。比如:

  1. 可一次拿一个样本,但训练时我们通常一次拿多个样本(i.e. minibatch)
  2. 每个 epoch 需要 reshuffle 数据避免过拟合(数据规律重复)
  3. 没有多进程加载数据的机制

这些弊端在 DataLoader 类中都可以抽象完成,迎刃而解。

from torch.utils.data import DataLoader
# 将 Dataset 对象加载到 DataLoader,所有 batch 都采样过一遍后对整个数据集 shuffle 一次
train_dataloader = DataLoader(train_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)
# iteration
train_features, train_labels = next(iter(train_dataloader))
img = train_features[0].squeeze()
label = train_labels[0]
plt.title(label)
plt.imshow(img, cmap="gray")
plt.show()

关于 PyTorch DataLoader 更细节的文档,可移步:torch.utils.data

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值