【DGL】dgl邻居节点采样器MultiLayerNeighborSampler

dgl.dataloading.neighbor.MultiLayerNeighborSamplericon-default.png?t=M276https://docs.dgl.ai/en/0.6.x/api/python/dgl.dataloading.html?highlight=multilayerneighborsampler#neighbor-sampler基于 `dgl.dataloading.dataloader.BlockSampler`

采样器,通过多层 GNN 的邻居采样建立节点表示的计算依赖。

该采样器将使每个节点从每种边缘类型的固定数量的邻居中收集消息。邻居是统一挑选的。

参数

从第一层开始,每个 GNN 层的每个边缘类型的邻居列表。

如果图是同构的,则每一层只需要一个整数。

如果为某一层提供值为 None ,则任何边类型的所有邻居都会包含在内。

如果为一层上的一种边类型提供值为 -1,则将包括该边类型的所有入站边。

  • replace (booldefault True) -- 是否进行替换采样
  • return_eids (booldefault False) – 是否返回 MFG 中消息传递所涉及的边 ID。如果为 True,边 ID 将存储为名为  dgl.EID 的边特征。

例子

为了在同构图上的一组节点 train_nid 上训练 3 层 GNN 进行节点分类,其中每个节点分别从第一、第二和第三层的 5、10、15 个邻居获取消息(假设后端是 PyTorch):

import dgl
import torch

sampler = dgl.dataloading.MultiLayerNeighborSampler([5, 10, 15])
collator = dgl.dataloading.NodeCollator(g, train_nid, sampler)
dataloader = torch.utils.data.DataLoader(
    collator.dataset, collate_fn=collator.collate,
    batch_size=1024, shuffle=True, drop_last=False, num_workers=4)
for blocks in dataloader:
    train_on(blocks)

如果在异构图上进行训练,并且希望每种边类型有不同数量的邻居,则应该提供一个dict列表。每个dict将指定每种边类型要拾取的邻居数。

sampler = dgl.dataloading.MultiLayerNeighborSampler([
    {('user', 'follows', 'user'): 5,
     ('user', 'plays', 'game'): 4,
     ('game', 'played-by', 'user'): 3}] * 3)

  • 2
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值