前言
将数据集划分完训练集、验证集、测试集后,便可以进行模型的训练了,对于任意一个模型的训练,都离不开的一个话题便是数据读取。
在PyTorch中数据读取的核心是Dataloader。Dataloader分为Sampler和DataSet两个子模块。Sampler的功能是生成索引,即样本序号;DataSet的功能是根据索引读取样本和对应的标签。
DataLoader
参数详解:
DataLoader(dataset,
batch_size=1,
shuffle=False,
sampler=None,
batch_sampler=None,
num_works=0,
clollate_fn=None,
pin_memory=False,
drop_last=False,
timeout=0,
worker_init_fn=None,
multiprocessing_context=None)
● dataset:Dataset类,决定数据从哪读取及如何读取
● batchsize:批大小
● num_works:是否多进程读取数据,当数据越大时建议将它设大,否则可能跑很慢
● shuffle:每个epoch是否乱序
● drop_last:当样本数不能被batchsize整除时,是否舍弃最后一批数据
● sampler
:用于从数据集中抽取样本的方法,可以是torch.utils.data.Sampler
的子类实例。如果设置了sampler
,则shuffle
参数将被忽略
●collate_fn
:用于将一个batch的数据进行整理和组合的函数,默认为None。通常用于文本、图像等需要整理的数据。
●pin_memory
:是否将数据存储在CUDA固定内存中,默认为False。如果设置为True,可以提高GPU数据传输速度
●timeout
:从工作进程中获取数据的超时时间,默认为0(无超时)
●worker_init_fn
:用于初始化工作进程的函数,默认为None。可以用于设置工作进程的种子等
●multiprocessing_context
:用于指定多进程上下文的类型,可以是'fork'(默认)或'spawn'
概念辨析:
● epoch:所有训练样本都已输入到模型中,称为一个epoch
● iteration:一批样本输入到模型中,称之为一个iteration
● batchsize:批大小,决定一个epoch中有多少个iteration
样本总数:80,batchsize:8 (样本能被batchsize整除)
● 1(epoch) = 10(iteration)
样本总数:87,batchsize=8 (样本不能被batchsize整除)
● drop_last = True:1(epoch) = 10(iteration)
● drop_last = False:1(epoch)= 11(iteration)