dgl.dataloading.neighbor.MultiLayerNeighborSamplerhttps://docs.dgl.ai/en/0.6.x/api/python/dgl.dataloading.html?highlight=multilayerneighborsampler#neighbor-sampler基于 `dgl.dataloading.dataloader.BlockSampler`
采样器,通过多层 GNN 的邻居采样建立节点表示的计算依赖。
该采样器将使每个节点从每种边缘类型的固定数量的邻居中收集消息。邻居是统一挑选的。
参数
从第一层开始,每个 GNN 层的每个边缘类型的邻居列表。
如果图是同构的,则每一层只需要一个整数。
如果为某一层提供值为 None ,则任何边类型的所有邻居都会包含在内。
如果为一层上的一种边类型提供值为 -1,则将包括该边类型的所有入站边。
- replace (bool, default True) -- 是否进行替换采样
- return_eids (bool, default False) – 是否返回 MFG 中消息传递所涉及的边 ID。如果为 True,边 ID 将存储为名为
dgl.EID
的边特征。
例子
为了在同构图上的一组节点 train_nid
上训练 3 层 GNN 进行节点分类,其中每个节点分别从第一、第二和第三层的 5、10、15 个邻居获取消息(假设后端是 PyTorch):
import dgl
import torch
sampler = dgl.dataloading.MultiLayerNeighborSampler([5, 10, 15])
collator = dgl.dataloading.NodeCollator(g, train_nid, sampler)
dataloader = torch.utils.data.DataLoader(
collator.dataset, collate_fn=collator.collate,
batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
for blocks in dataloader:
train_on(blocks)
如果在异构图上进行训练,并且希望每种边类型有不同数量的邻居,则应该提供一个dict列表。每个dict将指定每种边类型要拾取的邻居数。
sampler = dgl.dataloading.MultiLayerNeighborSampler([
{('user', 'follows', 'user'): 5,
('user', 'plays', 'game'): 4,
('game', 'played-by', 'user'): 3}] * 3)