Pytorch官方文档:API for pytorch
DataLoader函数功能:
- 生成数据集的可迭代对象;
- 利用多线程加速batch data处理;
- 简洁、高效、直观的用于网络输入的数据结构,使用灵活,便于扩展
DataLoader类位于torch.utils.data包下,官方API介绍如下:
常用参数说明:
- dataset(Dataset):输入数据集
- batch_size(int, optional): 每个batch送入多少数据集
- shuffle(bool, optional): 是否进行重新排列
实例:加载MNIS数据集并转化为dataloader格式:
import torch
import torchvision
import torchvision.datasets as dsets
import torchvision.transforms as transforms
#define hyperparameter
EPOCH = 1
BATCH_SIZE = 64
TIME_STEP = 28 #time_step / image_height
INPUT_SIZE = 28 #input_step / image_width
LR = 0.01
DOWNLOAD = True
#get the mnist dataset
train_data = dsets.MNIST(root='./', train=True, transform= torchvision.transforms.ToTensor(), download=DOWNLOAD)
#use dataloader to batch input dateset
train_loader = torch.utils.data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
#......#