pytorch中Dataset、Dataloader、Sampler、collate_fn相互关系和使用说明

        Dataset、Dataloader、Sampler和collate_fn是在PyTorch中用于处理数据加载和预处理的关键组件。它们相互关联并协同工作以有效地加载、处理训练和测试数据。

1.Dataset(数据集):

  • Dataset是一个抽象类,在PyTorch中用于表示数据集。
  • 为了使用Dataset类,你需要继承它并实现两个主要方法:__len__和__getitem__。__len__方法返回数据集的样本数量。__getitem__方法根据给定的索引返回一个样本。可以根据自己的数据集格式和需求进行实现。
  • Dataset类负责提供数据的访问接口,但不涉及数据的加载和处理。

2.Sampler(采样器):

  • Sampler是一个用于定义样本抽样策略的类。它决定了在每个epoch中以什么顺序访问数据集中的样本。
  • PyTorch提供了几种内置的Sampler类,如RandomSampler(随机采样)、SequentialSampler(顺序采样)和SubsetRandomSampler(子集随机采样)等。你也可以自定义Sampler类来实现特定的采样逻辑。
  • Sampler类不直接加载或处理数据,它只负责确定样本的访问顺序。

3.Dataloader(数据加载器):

  • Dataloader是一个用于加载数据的迭代器。它接受一个Dataset对象和一个Sampler对象(可选),并提供了一种从数据集中批量加载数据的方式。
  • Dataloader可以指定批量大小(batch size)、并行加载数据的线程数、是否打乱数据等参数。
  • 通过迭代Dataloader对象,你可以按照指定的批量大小逐批获取数据。

4.collate_fn(数据处理函数):

  • collate_fn是一个可选的参数,用于自定义如何处理从数据集中获取的样本。
  • 当你的样本具有不同的大小或数据类型时,collate_fn可以用于对样本进行预处理和转换,以便能够创建一个批次的张量(batch tensor)。
  • collate_fn接收一个样本列表作为输入,并返回经过处理的批次样本。
  • collate_fn通常用于在数据加载之前对样本进行预处理,例如填充变长序列或转换样本为张量或者进行特定的数据组织处理,如YOLOV5中将图片和label重新进行整合:将batch_size个[3, 640, 640]的矩阵拼成一个[batch_size, 3, 640, 640] 将label[n1,6]、[n2,6]、[n3,6]...拼接成[n1+n2+n3+..., 6] 

5.使用说明 

  1. 创建自定义Dataset类,继承Dataset并实现__len__和__getitem__方法。
  2. 根据需要选择合适的Sampler类,或者使用默认的顺序采样。
  3. 创建Dataloader对象,传入Dataset对象和Sampler对象(可选),设置批量大小和其他参数。
  4. 可选地定义一个collate_fn函数来处理样本列表中的样本,以便能够创建一个批次的张量。
  5. 使用迭代器方式访问Dataloader对象,以批量获取数据进行训练或测试。

6.代码示例

        代码模拟DataLoader:

import torch
from torch.utils.data import BatchSampler

#模拟生成数据集
image = torch.randn(10, 2)
label = torch.randint(low=0, high=2,size=(10,)).float()
dataset = torch.utils.data.TensorDataset(image, label)
print("dataset_len = {0}".format(len(dataset)))
for i in range(len(dataset)): #打印数据
    print('dataset[{0}]= {1}'.format(i, dataset[i]))

############# 以下代码模拟实现DataLoader相关功能 #############
#1.生成抽样索引,相当于DataLoader中的Sampler和BatchSampler的实现
sampler = torch.utils.data.RandomSampler(data_source=dataset)  #数据随机采样
batch_sampler = torch.utils.data.BatchSampler(sampler=sampler, batch_size=4,drop_last=False);
for i in batch_sampler:
    index = i #返回列表 BatchSampler内部有yield
    break;
print("batch索引index:", index)

#2.取出一个batch的样本
one_batch_sample = [dataset[i] for i in index]
print("batch数据:", one_batch_sample)

#3. 数据处理成feature和labels, 模拟DataLoader中的collate_fn函数
def collate_fn(batch):
    features = torch.stack([data[0] for data in batch])
    labels = torch.stack([data[1] for data in batch])
    return features,labels

features,labels = collate_fn(batch=one_batch_sample)
print("获取features为:", features)
print("获取的labels为:", labels)
  • 28
    点赞
  • 20
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值