pytorch 数据读取之(Dataset,DataLoader)


因为每次和数据打交道,天天可以碰到 torch.utils.data.Dataset, torch.utils.data.DataLoader

我看到的代码 都是一步一步封装,首先定义数据增强的措施,然后把这些措施封装到预处理中(这里用到了torchvision.transforms),定义好预处理后,就应该采样了sample(这里继承了Dataset,一般用RandomSampler,BatchSampler,SequentialSampler),里面封装了预处理,接着进行数据的加载loader(这里继承了Dataloader),这里封装的就是之前定义的sample,最后dataloader封装进了dataloaderIter中,进行逐步迭代。当然还需要对数据集进行定义,编写需要的方法(这里继承了Dataset)。 当程序开始执行的时候,它一步一步倒着去执行,依次遍历,首先 根据得到batch索引,然后根据索引得到数据,接着进行处理,最终得到所需要的数据。

DataLoaderIter && DataLoader

Data loader. Combines a dataset and a sampler, and provides an iterable over the given dataset.
一般先进行定义:

 loader_train = LTRLoader('train', dataset_train, 
                            training=True, 
                            batch_size=settings.batch_size,
                            num_workers=settings.num_workers,
                            shuffle=True, drop_last=True, stack_dim=1)

其中dataset_train,就是Dataset类型,也就是定义的sampler,这里最主要的是重写__ getitem__()方法,得到数据。
一般运行代码:
for i,data in enumenrate(loader,1):
就自动跳到了dataloader.pyclass DataLoader()中的 def __ iter 方法,在这里插入图片描述
在这里选择使用单线程还是多线程进行数据的迭代,这里的MultiProcessingDataLoaderIter( Iterates once over the DataLoader’s dataset, as specified by the sampler)继承的是BaseDataLoaderIter,开始初始化,然后Dataloader进行初始化,然后进入
next __()方法 随机生成索引,进而生成batch,最后调用 _get_data() 方法得到data。idx, data = self._get_data(), data = self.data_queue.get(timeout=timeout)
这里用到了
队列

总结一下:
1.调用了dataloader 的__iter__() 方法, 产生了一个DataLoaderIter
2.反复调用DataLoaderIter 的__next__()来得到batch, 具体操作就是, 多次调用dataset的__getitem__()方法 (如果num_worker>0就多线程调用), 然后用collate_fn来把它们打包成batch. 中间还会涉及到shuffle , 以及sample 的方法等,
3当数据读完后, next()抛出一个StopIteration异常, for循环结束, dataloader 失效.

因此dataloader作用:
1.定义了一堆成员变量, 到时候赋给DataLoaderIter,
2.然后有一个__iter__() 函数, 把自己 “装进” DataLoaderIter 里面.
以下是DataLoader()的参数:
在这里插入图片描述
其中shuffle参数和sampler参数相关,sampler option is mutually exclusive with shuffle
Dataloader中存在一个默认的collate_fn函数
在这里插入图片描述
需要根据自己的需求重写collate_fn函数,该函数的作用是将得到的数据整理成一个batch。
代码为:先判断batch数据的类型,然后形成batch
在这里插入图片描述

Dataset

基本的Dataset类如下:有__init__() 、__ getitem__()、__ len__()、__ iter__() 方法
在这里插入图片描述
当然你定义的数据集是需要继承此类,并且覆写 _ init_() 、__ getitem__()、_ len_() 方法,甚至自己实现get_frames(),get_sequence_info()等方法
如果你要加载你的数据也需要继承此类,并且覆写 _ init_() 、__ getitem__()、_ len_() 方法,可以调用数据集中自己写的方法

第一个初始化,定义你用于训练的数据集,以什么比例进行sample(多个数据集的情况),每个epoch训练样本的数目,预处理方法等等
第二个是根据索引得到所需要的数据

video_keys=self.videos.keys() 
video = self.videos[video_keys[rand_vid]] #rand_vid为索引
video_ids=video[0]
video_id_keys=video.ids.keys()
rand_trackid_z = np.random.choice(list(range(len(video_id_keys))))
  
  #simafc中经过一系列的路径,文件名,进行随机选择 需要的图片,如果有预处理方法,再得到图片后进行预处理。有时候除了加载所需要的图片,还要加载真值

第三个是返回数据的长度

要想进行下一步的操作,读取正确的数据,并进行一定的处理是很重要的。

  • 3
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
PyTorch中,数据读取是构建深度学习模型的重要一环。为了高效处理大规模数据集,PyTorch提供了三个主要的工具:DatasetDataLoader和TensorDatasetDataset是一个抽象类,用于自定义数据集。我们可以继承Dataset类,并重写其中的__len__和__getitem__方法来实现自己的数据加载逻辑。__len__方法返回数据集的大小,而__getitem__方法根据给定的索引返回样本和对应的标签。通过自定义Dataset类,我们可以灵活地处理各种类型的数据集。 DataLoader数据加载器,用于对数据集进行批量加载。它接收一个Dataset对象作为输入,并可以定义一些参数例如批量大小、是否乱序等。DataLoader能够自动将数据集划分为小批次,将数据转换为Tensor形式,然后通过迭代器的方式供模型训练使用。DataLoader数据准备和模型训练的过程中起到了桥梁作用。 TensorDataset是一个继承自Dataset的类,在构造时将输入数据和目标数据封装成Tensor。通过TensorDataset,我们可以方便地处理Tensor格式的数据集。TensorDataset可以将多个Tensor按行对齐,即将第i个样本从各个Tensor中取出,构成一个新的Tensor作为数据集的一部分。这对于处理多输入或者多标签的情况非常有用。 总结来说,Dataset提供了自定义数据集的接口,DataLoader提供了批量加载数据集的能力,而TensorDataset则使得我们可以方便地处理Tensor格式的数据集。这三个工具的配合使用可以使得数据处理变得更加方便和高效。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值