我们做深度学习大部分时候的数据都是以数据+标注(CV)或者是纯文本(NLP)的形式存在的。
在开始一个项目时首先面对的就是如何把未经处理的数据整合成torch能识别的tensor。为此,torch提供了抽象类Datasets,它能很方便的把你的数据封装成一个可迭代的DataLoader供你使用。
要自定义数据集,首先要继承抽象类torch.utils.data.Dataset,并且需要重载两个重要的函数:__len__ 和__getitem__,前者提供了数据集的大小,后者支持整数索引,范围从0到len(self)。
import torch
from torch.utils import data
class MyDataset(data.Dataset):
def __init__(self):
super(MyDataset, self).__init__()
self.data = torch.randn(8,2)#八个数据,两个一组
def __getitem__(self, index):
img,label=self.data[index][0],self.data[index][1]
return img,label
def __len__(self):
return self.data.size()[0]
mydata = MyDataset()
在有标注(例如csv文件)时,我们可以简单的将csv转化为列表来完成__getitem__和__len__操作,__len__需要我们返回自己数据集的长度,__getitem__需要我们返回遍历时每次需要读取的数据(例如图片+标注数据集就返回img和label)
这样,我们自己的数据集就定义好了。接下来需要加载。加载之后的dataloader对象就可以直接遍历了。
print(len(mydata))
data_loader = data.DataLoader(mydata,batch_size=2,shuffle=False)
for img,label in enumerate(data_loader):
print(img,labbel)
在更多时候我们需要将数据提前处理成对应shape的tensor,这就是数据预处理了,例如图像增强之类的操作都可以在__init__里面写。