pytorch的自定义数据集/DataLoader和Dataset重写

背景介绍

  做Modulation Recognition的时候需要加载自定义的数据集,这就涉及到DataLoader和Dataset类中的方法重写了。

DataLoader介绍

  源码中的介绍是:

*Data loader. Combines a dataset and a sampler, and provides an iterable over the given dataset.*

  也就是说,我们可以通过输入一个数据集,及常用参数如:batch_size、shuffle,就可以得到一个打包好的迭代器。这个迭代器包含了batch_size的序号及根据batch_size分割好的数据块。

Dataset 介绍

  源码中的介绍是:

An abstract class representing a :class:`Dataset`.

  很短,但是很经典。这是一个抽象类。所谓抽象类就是类的抽象化,而类本身就是不存在的,所以抽象类无法实例化。它存在的意义就是被继承。而且继承抽象类的类必须要重写抽象类的方法。
  简单的说,我们构造一个MyDataset数据类,需要继承Dataset,并重写Dataset中的方法。

  去掉源码中的注释,Dataset抽象类的定义就五行代码,两个方法:

class Dataset(object):
    def __getitem__(self, index):
        raise NotImplementedError

    def __add__(self, other):
        return ConcatDataset([self, other])
        
    # No `def __len__(self)` default?
    # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
    # in pytorch/torch/utils/data/sampler.py

  根据我们的需要,我们会重写 __getitem__方法,以及__len__方法。

工作原理

  首先,我们要定义自己的数据集类,例如叫做MyDataset,则代码片段应该为:

class MyDataSet(Dataset):
    def __init__(self, data, label):
        self.data = data
        self.label = label
        self.length = data.shape[0]
        
    def __getitem__(self, mask):
        label = self.label[mask]
        data = self.data[mask]
        return label, data

    def __len__(self):
        return self.length

继承

  很简单

class MyDataSet(Dataset):

  表示我们MyDataSet类继承了抽象类Dataset。该MyDataSet类中的有三个方法。

__init__方法

  __init__方法是python中的构造方法(java中是叫构造方法,不知道python是不是这么叫,如果不是请大家指正),构造方法会在实例化对象时调用。其传入参数就是我们的数据集(data)和标签集(label)。

 def __init__(self, data, label):
        self.data = data
        self.label = label
        self.length = data.shape[0]

__getitem__方法

  __getitem__方法是获取返回数据的方法,传入参数是一个index,也被叫做mask,就是我们对数据集的选择索引。在自己使用时,比如想从data = [100, 99, 98, …, 0]的集合中选出下标为[0, 2, 4]的集合,则index/mask 就取[0, 2, 4],返回data[index]即可。
  其实在调用DataLoader时就会自己生成index,所以我们只需要写好方法即可。

 def __getitem__(self, mask):
        return self.label[mask], self.data[mask]

__len__方法

  偷了个懒没有去看源码。听说不给返回length的话pytorch会一脸xx。

 def __len__(self):
        return self.length

使用

  完成了MyDataSet,就可以通过DataLoader使用了。例如此处我已经有了一个X_train,其中的数据的每一个batch都代表了一个信号。Y_train当中都是X_train对应的标签。
  于是我的代码就是:

train_set = MyDataSet(data=X_train, label=Y_train)
num_epoch = 100     # number of epochs to train on
batch_size = 1024  # training batch size
train_data = DataLoader(train_set, batch_size=batch_size, shuffle=True)
for epoch in range(num_epoch ):
    model.train()
    for batchsz, (label, data) in enumerate(train_data):
        # i表示第几个batch, data表示该batch对应的数据,包含data和对应的labels
        print("第 {} 个Batch size of label {} and size of data{}".format(batchsz, label.shape, data.shape))

  DataLoader会根据设置的batch_size来产生index/mask,然后调用Datase的__getitem__方法取出数据。
  输出结果如下:
在这里插入图片描述
  接下来就可以愉快的写模型了!!!

总结

  其实看起来很简单的一个Dataset抽象类重写和DataLoader使用,包含了面向对象编程的三大特点:封装继承多态

  • 封装体现在Dataset抽象类的封装及我们的MyDataSet类的封装上。
  • 继承体现在我们MyDataSet继承Dataset抽象类上。
  • 多态体现在DataLoader对数据集的操作上(这点纯属个人理解,感觉有点像java中的向上转型,但python好像没有这一概念)。
  • 21
    点赞
  • 56
    收藏
    觉得还不错? 一键收藏
  • 3
    评论
PyTorch中,数据读取是构建深度学习模型的重要一环。为了高效处理大规模数据集PyTorch提供了三个主要的工具:DatasetDataLoader和TensorDatasetDataset是一个抽象类,用于自定义数据集。我们可以继承Dataset类,并重写其中的__len__和__getitem__方法来实现自己的数据加载逻辑。__len__方法返回数据集的大小,而__getitem__方法根据给定的索引返回样本和对应的标签。通过自定义Dataset类,我们可以灵活地处理各种类型的数据集DataLoader是数据加载器,用于对数据集进行批量加载。它接收一个Dataset对象作为输入,并可以定义一些参数例如批量大小、是否乱序等。DataLoader能够自动将数据集划分为小批次,将数据转换为Tensor形式,然后通过迭代器的方式供模型训练使用。DataLoader在数据准备和模型训练的过程中起到了桥梁作用。 TensorDataset是一个继承自Dataset的类,在构造时将输入数据和目标数据封装成Tensor。通过TensorDataset,我们可以方便地处理Tensor格式的数据集。TensorDataset可以将多个Tensor按行对齐,即将第i个样本从各个Tensor中取出,构成一个新的Tensor作为数据集的一部分。这对于处理多输入或者多标签的情况非常有用。 总结来说,Dataset提供了自定义数据集的接口,DataLoader提供了批量加载数据集的能力,而TensorDataset则使得我们可以方便地处理Tensor格式的数据集。这三个工具的配合使用可以使得数据处理变得更加方便和高效。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值