DataLoader和Dataset是pytorch中数据读取的核心
能够将自己的科研数据输入到模型当中是进行研究的最初的一步也是最重要的一步
一、torch.utils.data.Dataset
该类是用来定义数据从哪里读取,以及如何读取的问题。
在使用该类时,需要先继承该类。
那么具体如何读取,需要复写其中的方法
__getitem__() # 最为重要, 即每次怎么读数据,接受一个item索引,返回该索引的样本
__len__() #len()返回值的是常数
具体的类的实现如下:
class MyDataSet(Dataset):
def __init__(self, graphs, labels):
self.graphs = graphs
self.labels = labels
def __len__(self):
##返回数据的长度
return len(self.labels)
def __getitem__(self, item):
# __getitem__函数的作用是根据索引index遍历数据
#loc是根据index来索引对应的行
graph = graphs[item]
g = self.labels[item]
return [graph, g]
二. torch.utils.data.DataLoader
构建可迭代的数据装载器, 我们在训练的时候,由于数据量太大,内存吃不消需要分批次送入模型当中。
每一个for循环,每一次iteration,就是从DataLoader中获取一个batch_size大小的数据的。
class torch.utils.data.DataLoader(dataset,
batch_size=1,
shuffle=False,
sampler=None,
batch_sampler=None,
num_workers=0,
collate_fn=<function default_collate>,
pin_memory=False,
drop_last=False)
参数很多,但需要关注的点就几个
dataset: Dataset类, 决定数据从哪读取以及如何读取
bathsize: 批大小
num_works: 是否多进程读取机制(参与工作的线程数)
shuffle: 每个epoch是否乱序(取batch是否随机取, 默认为False)
drop_last: 当样本数不能被batchsize整除时, 是否舍弃最后一批数据
collate_fn:对取出的batch进行处理(有时候根据需求需要自己定义) 下面详细讲该参数
三、dataloader之collate_fn
dataloader=DataLoader(dataset,batch_size=2)
batch_size=2即一个batch里面会有2个数据。我们以第1个batch为例,DataLoader会根据dataset取出前2个数据,然后弄成一个列表:
batch=[dataset[0],dataset[1]]
然后将上面这个batch作为参数交给collate_fn这个函数进行进一步整理数据,然后得到real_batch,作为返回值。如果你不指定这个函数是什么,那么会调用pytorch内部的collate_fn(将列表转化成tensor)。
定义其实很容易上手
def my_collate(batch):#batch上面说过,是dataloader传进来的。
graphs, labels = map(list, zip(*bathc))
real_batch=***
return real_batch
其中的real_batch= *** 写成你要的操作即可