超大规模数据集类的创建
前面我们只接触了数据可全部储存于内存的数据集,这些数据集对应的数据集类在创建对象时就将所有的数据加载到内存。然而如果数据集规模超级大,我们很难有足够大的内存完全存下所有数据。所以需要一个按需加载样本到内存的数据集类。
Dataset类
在PyG中,我们通过继承torch_geometric.data.Dataset基类来自定义一个按需加载样本到内存的数据集类。继承torch_geometric.data.InMemoryDataset基类要实现的方法,继承此基类同样要实现,此外还需实现以下方法:
- len():返回数据集中的样本的数量
- get():实现加载单个图的操作。在内部,getitem()返回通过调用get()来获取Data对象,并根据transform参数对它们进行选择性转换。
通过下面的方法我们可以不用定义一个Dataset类,而直接生成一个Dataloader对象,直接用于训练:
from torch_geometric.data import Data, DataLoader
data_list = [Data(...), ..., Data(...)]
loader = DataLoader(data_list, batch_size=32)
我们也可以通过下面的方式将一个列表的Data对象组成一个batch:
from torch_geometric.data import Data, Batch
data_list = [Data(...), ..., Data(...)]
loader = Batch.from_data_list(data_list, ba