【python】pytorch中如何使用DataLoader对数据集进行批处理

第一步:

我们要创建torch能够识别的数据集类型(pytorch中也有很多现成的数据集类型,以后再说)。

首先我们建立两个向量X和Y,一个作为输入的数据,一个作为正确的结果:

    

随后我们需要把X和Y组成一个完整的数据集,并转化为pytorch能识别的数据集类型:

    

我们来看一下这些数据的数据类型:

     

可以看出我们把X和Y通过Data.TensorDataset() 这个函数拼装成了一个数据集,数据集的类型是【TensorDataset】。

好了,第一步结束了

 


 

第二步:

就是把上一步做成的数据集放入Data.DataLoader中,可以生成一个迭代器,从而我们可以方便的进行批处理。

     

DataLoader中也有很多其他参数:

复制代码

dataset:Dataset类型,从其中加载数据 
batch_size:int,可选。每个batch加载多少样本 
shuffle:bool,可选。为True时表示每个epoch都对数据进行洗牌 
sampler:Sampler,可选。从数据集中采样样本的方法。 
num_workers:int,可选。加载数据时使用多少子进程。默认值为0,表示在主进程中加载数据。 
collate_fn:callable,可选。 
pin_memory:bool,可选 
drop_last:bool,可选。True表示如果最后剩下不完全的batch,丢弃。False表示不丢弃。

复制代码

好了,第二步结束了,

 


 

第三步:

好啦,现在我们就可以愉快的用我们上面定义好的迭代器进行训练啦。

在这里我们利用print来模拟我们的训练过程,即我们在这里对搭建好的网络进行喂入。

     

输出的结果是:

      

可以看到,我们一共训练了所有的数据训练了5次。数据中一共10组,我们设置的mini-batch是3,即每一次我们训练网络的时候喂入3组数据,到了最后一次我们只有1组数据了,比mini-batch小,我们就仅输出这一个。

此外,还可以利用python中的enumerate(),是对所有可以迭代的数据类型(含有很多东西的list等等)进行取操作的函数,用法如下:

      

 

好啦,结束。

转载自:https://www.cnblogs.com/JeasonIsCoding/p/10168753.html

PytorchDataLoader是一个方便的数据加载器,它可以批量地加载数据,并在训练神经网络时提供数据。DataLoader的主要作用是将数据集分成批次,并且在每个epoch对数据进行随机化,以避免模型过度拟合。 在使用DataLoader之前,需要先定义一个数据集,并将其传递给DataLoader数据集需要实现__getitem__和__len__方法,以便DataLoader可以获取每个样本以及数据集的大小。 例如,一个简单的数据集可以如下所示: ```python class MyDataset(Dataset): def __init__(self, data): self.data = data def __getitem__(self, idx): return self.data[idx] def __len__(self): return len(self.data) ``` 然后,可以使用DataLoader数据集进行批处理: ```python dataset = MyDataset(data) dataloader = DataLoader(dataset, batch_size=32, shuffle=True) ``` 在上面的代码,batch_size参数指定了每个批次的样本数量,shuffle参数指定是否对数据进行随机化。 一旦创建了DataLoader,就可以通过迭代器访问数据集的批次。例如: ```python for batch in dataloader: # 处理当前批次的数据 ``` 需要注意的是,每个批次返回的是一个tensor的列表,而不是单个tensor。这是因为在训练神经网络时,通常需要对输入数据和标签进行分离处理。因此,每个批次包含输入数据和对应的标签。可以使用torch.Tensor.split()方法将tensor列表分离成输入和标签。
评论 7
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值