机器学习的五大模块:
数据模块又可分为以下几部分:
● 数据的收集:Image、label
● 数据的划分:train、test、valid
● 数据的读取:DataLoader,有两个子模块,Sampler和Dataset,Sampler是对数据集生成索引index,DataSet是根据索引读取数据
● 数据预处理:torchvision.transforms模块
所以这一节主要介绍pytorch中数据的读取模块
一、DataLoader
torch.utils.data.DataLoader():构建可迭代的数据装载器,在训练数据时,每一个for循环,就是一次iteration,就是从DataLoader中获取一个batchsize大小的数据
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
batch_sampler=None, num_workers=0, collate_fn=None,
pin_memory=False, drop_last=False, timeout=0,
worker_init_fn=None, *, prefetch_factor=2,
persistent_workers=False)
常用的参数:
dataset:Dataset类,决定从哪读取以及如何读取数据;
batch_size:int型,批量的大小
shuffle:每个epoch的数据是否打乱
num_workers:是否进行多进程读取数据,若采取多进程,减少读取数据的时间,可以加速模型训练
drop_last:当样本数不能被batchsize整除时,是否舍弃最后一批数据
另外:Epoch,Iteration,Batchsize的区别
Epoch:所有的数据都输入到模型中,称为一个Epoch;
Iteration:一批样本输入到模型中,为一次Iteration;
Batchsize:输入到模型中的一批样本的大小;
例:假设样本总数是80,batchsize是8,那么一个epoch=10次Iteration;假设样本总数是87,batchsize是8,如果drop_last=True,最后一批数据不满足batchsize 8,舍去,一个epoch=10次Iteration,若drop_last=False,不舍去最后一批数据, 一个epoch=11次Iteration,最后一次的Iteration有7个样本。
二、Dataset
torch.utils.data.Dataset():Dataset类,所有自定义的数据集都要继承这个类,并且复写__getitem__()这个类方法;定义数据从哪里读取以及如何读取。
class Dataset(object):
def __init__(self):
pass
def __len__(self):
raise NotImplementedError
def __getitem__(self,index)