- DGL提供了几个邻居采样类,这些类会生成需计算的节点在每一层计算时所需的依赖图。
- 最简单的邻居采样器是 MultiLayerFullNeighborSampler,它可获取节点的所有邻居。
- 要使用DGL提供的采样器,还需要将其与 NodeDataLoader 结合使用,后者可以以小批次的形式对一个节点的集合进行迭代。
例如,以下代码创建了一个PyTorch的 DataLoader,它分批迭代训练节点ID数组 train_nids
, 并将生成的子图列表放到GPU上。
import dgl
import dgl.nn as dglnn
import torch
import torch.nn as nn
import torch.nn.functional as F
sampler = dgl.dataloading.MultiLayerFullNeighborSampler(2)
dataloader = dgl.dataloading.NodeDataLoader(
g, train_nids, sampler,
batch_size=1024,
shuffle=True,
drop_last=False,
num_workers=4)
对DataLoader进行迭代,将会创建一个特定图的列表,这些图表示每层的计算依赖。在DGL中称之为 blocks
。
input_nodes, output_nodes, blocks = next(iter(dataloader))
print(blocks)
(1)The iterator generates three items at a time.
(2)
input_nodes
describe the nodes needed to compute the representation of output_nodes
.
input_nodes
描述了计算【output_nodes
表示】所需的节点。
(3)
blocks
describe for each GNN layer which node representations are to be computed as output, which node representations are needed as input, and how does representation from the input nodes propagate to the output nodes.
blocks 描述了 对每个GNN层而言,哪些【节点表示】将被计算为输出,哪些【节点表示】需要作为输入,以及【表示】如何从输入节点的表示如何传播到输出节点。
每次迭代器生成三项:input_nodes, output_nodes, blocks
- output_nodes是batch中的目标顶点,即需要计算表示的顶点(最后一层的输出顶点)。
- input_nodes是计算这些顶点的表示所需要的输入顶点(第1层的输入顶点)。
- blocks是每一层GNN的输入图,第i个block由第i层的输入顶点和第i层的输出顶点(第i+1层的输入顶点)组成。
- input_nodes对应blocks[0].srcnodes(),output_nodes对应blocks[-1].dstnodes(),但input_nodes始终将output_nodes放在最前面,即input_nodes[:len(output_nodes)] == output_nodes。
- dataloader生成的block对顶点进行了重新编号,原id保存在名为
dgl.NID
的顶点属性中。