在对数据批量处理以进行训练时,
(一)首先我们需要数据库集
我们建立两个向量X和Y,一个作为输入的数据,一个作为正确的结果:
随后我们需要把X和Y组成一个完整的数据集,并转化为pytorch能识别的数据集类型:
现在看数据的数据类型
(二)就是把上一步做成的数据集放入Data.DataLoader中,可以生成一个迭代器,从而我们可以方便的进行批处理
DataLoader内部参数介绍
dataset:Dataset类型,从其中加载数据 batch_size:int,可选。每个batch加载多少样本 shuffle:bool,可选。为True时表示每个epoch都对数据进行洗牌 sampler:Sampler,可选。从数据集中采样样本的方法。 num_workers:int,可选。加载数据时使用多少子进程。默认值为0,表示在主进程中加载数据。 collate_fn:callable,可选。 pin_memory:bool,可选 drop_last:bool,可选。True表示如果最后剩下不完全的batch,丢弃。False表示不丢弃。
(三)通过DataLoader训练
好啦,现在我们就可以愉快的用我们上面定义好的迭代器进行训练啦。
在这里我们利用print来模拟我们的训练过程,即我们在这里对搭建好的网络进行喂入
输出的结果是:
可以看到,我们一共训练了所有的数据训练了5次。数据中一共10组,我们设置的mini-batch是3,即每一次我们训练网络的时候喂入3组数据,到了最后一次我们只有1组数据了,比mini-batch小,我们就仅输出这一个。
此外,还可以利用python中的enumerate(),是对所有可以迭代的数据类型(含有很多东西的list等等)进行取操作的函数,用法如下:
转载自:pytorch中如何使用DataLoader对数据集进行批处理 - 不愿透漏姓名的王建森 - 博客园