pytorch自定义数据集

我们做深度学习大部分时候的数据都是以数据+标注(CV)或者是纯文本(NLP)的形式存在的。

在开始一个项目时首先面对的就是如何把未经处理的数据整合成torch能识别的tensor。为此,torch提供了抽象类Datasets,它能很方便的把你的数据封装成一个可迭代的DataLoader供你使用。

要自定义数据集,首先要继承抽象类torch.utils.data.Dataset,并且需要重载两个重要的函数:__len__ 和__getitem__,前者提供了数据集的大小,后者支持整数索引,范围从0到len(self)。

import torch
from torch.utils import data
class MyDataset(data.Dataset):
    def __init__(self):
        super(MyDataset, self).__init__()
        self.data = torch.randn(8,2)#八个数据,两个一组
    def __getitem__(self, index):
        img,label=self.data[index][0],self.data[index][1]
        return img,label
    def __len__(self):
        return self.data.size()[0]
mydata = MyDataset()

在有标注(例如csv文件)时,我们可以简单的将csv转化为列表来完成__getitem__和__len__操作,__len__需要我们返回自己数据集的长度,__getitem__需要我们返回遍历时每次需要读取的数据(例如图片+标注数据集就返回img和label)

这样,我们自己的数据集就定义好了。接下来需要加载。加载之后的dataloader对象就可以直接遍历了。

print(len(mydata))
data_loader = data.DataLoader(mydata,batch_size=2,shuffle=False)
for img,label in enumerate(data_loader):
    print(img,labbel)

在更多时候我们需要将数据提前处理成对应shape的tensor,这就是数据预处理了,例如图像增强之类的操作都可以在__init__里面写。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值