引用自Datawhale
https://github.com/datawhalechina/team-learning-nlp/tree/master/GNN
前面文章:
- 图神经网络打卡task1
- Datawhale 图神经网络task2
- Datawhale 图神经网络task3
- Datawhale 图神经网络task4
- Datawhale 图神经网络task5
- Datawhale 图神经网络task6
超大规模数据集类的创建
在前面的学习中我们只接触了数据可全部储存于内存的数据集,这些数据集对应的数据集类在创建对象时就将所有数据都加载到内存。然而在一些应用场景中,数据集规模超级大,我们很难有足够大的内存完全存下所有数据。因此需要一个按需加载样本到内存的数据集类。在此上半节内容中,我们将学习为一个包含上千万个图样本的数据集构建一个数据集类。
Dataset
基类简介
在PyG中,我们通过继承torch_geometric.data.Dataset
基类来自定义一个按需加载样本到内存的数据集类。此基类与Torchvision的Dataset
类的概念密切相关,这与第6节中介绍的torch_geometric.data.InMemoryDataset
基类是一样的。
继承torch_geometric.data.InMemoryDataset
基类要实现的方法,继承此基类同样要实现,此外还需要实现以下方法:
len()
:返回数据集中的样本的数量。get()
:实现加载单个图的操作。注意:在内部,__getitem__()
返回通过调用get()
来获取Data
对象,并根据transform
参数对它们进行选择性转换。
下面让我们通过一个简化的例子看继承torch_geometric.data.Dataset
基类的规范:
import os.path as osp
import torch
from torch_geometric.data import Dataset, download_url
class MyOwnDataset(Dataset):
def __init__(self, root, transform=None, pre_transform=None):
super(MyOwnDataset, self).__init__(root, transform, pre_transform)
@property
def raw_file_names(self):
return ['some_file_1', 'some_file_2', ...]
@property
def processed_file_names(self):
return ['data_1.pt', 'data_2.pt', ...]
def download(self):
# Download to `self.raw_dir`.
path = download_url(url, self.raw_dir)
...
def process(self):
i = 0
for raw_path in self.raw_paths:
# Read data from `raw_path`.
data = Data(...)
if self.pre_filter is not None and not self.pre_filter(data):
continue
if self.pre_transform is not None:
data = self.pre_transform(data)
torch.save(data, osp.join(self.processed_dir, 'data_{}.pt'.format(i)))
i += 1
def len(self):
return len(self.processed_file_names)
def get(self, idx):
data = torch.load(osp.join(self.processed_dir, 'data_{}.pt'.format(idx)))
return data
其中,每个Data
对象在process()
方法中单独被保存,并在get()
中通过指定索引进行加载。
跳过download/process
对于无需下载数据集原文件的情况,我们不重写(override)download
方法即可跳过下载。对于无需对数据集做预处理的情况,我们不重写process
方法即可跳过预处理。
无需定义Dataset类
通过下面的方式,我们可以不用定义一个Dataset
类,而直接生成一个Dataloader
对象,直接用于训练:
from torch_geometric.data import Data, DataLoader
data_list = [Data(...), ..., Data(...)]
loader = DataLoader(data_list, batch_size=32)
我们也可以通过下面的方式将一个列表的Data
对象组成一个batch
:
from torch_geometric.data import Data, Batch
data_list = [Data(...), ..., Data(...)]
loader = Batch.from_data_list(data_list, batch_size=32)
图样本封装成批(BATCHING)与DataLoader
类
合并小图组成大图
图可以有任意数量的节点和边,它不是规整的数据结构,因此对图数据封装成批的操作与对图像与序列等数据封装成批的操作不同。PyTorch Geometric中采用的将多个图封装成批的方式是,将小图作为连通组件(connected component)的形式合并,构建一个大图。于是小图的邻接矩阵存储在大图邻接矩阵的对角线上。大图的邻接矩阵、属性矩阵、预测目标矩阵分别为:
KaTeX parse error: No such environment: split at position 8: \begin{̲s̲p̲l̲i̲t̲}̲\mathbf{A} = \b…
此方法有以下关键的优势:
-
依靠消息传递方案的GNN运算不需要被修改,因为消息仍然不能在属于不同图的两个节点之间交换。
-
没有额外的计算或内存的开销。例如,这个批处理程序的工作完全不需要对节点或边缘特征进行任何填充。请注意,邻接矩阵没有额外的内存开销,因为它们是以稀疏的方式保存的,只保留非零项,即边。
通过torch_geometric.data.DataLoader
类,多个小图被封装成一个大图。