NeighborLoader
什么是NeighborLoader
是torch_geometric.loader中提供的方法。
作用:用于高效采样节点及其邻接节点,可以处理大规模数据,
通过邻居采样来生成迷你批次,用于训练。
例子:
Data:
0
/ \
1 2
\ /
3
假设每个节点都有一个特征向量。如下:
# 定义节点特征 (4 个节点,每个节点有 3 维特征)
x = torch.tensor([[1, 1, 1],
[2, 2, 2],
[3, 3, 3],
[4, 4, 4]], dtype=torch.float)
# 定义边列表 (每条边由两个节点索引组成)
edge_index = torch.tensor([[0, 0, 1, 2, 3, 3],
[1, 2, 0, 3, 1, 2]], dtype=torch.long)
创建NeighborLoader
# 创建 NeighborLoader
loader = NeighborLoader(
data,
num_neighbors=[2, 2], # 每一层的邻居采样数量
batch_size=2, # 每个批次的节点数量
shuffle=True,
)
打印结果:
# 查看生成的批次
for batch in loader:
print("Batch数据:")
print("Batch节点索引:", batch.batch)
print("Batch节点特征:\n", batch.x)
print("Batch边列表:\n", batch.edge_index)
print("Batch邻居采样节点:\n", batch.n_id)
break # 只打印一个批次的内容
Batch节点特征: tensor([[2., 2., 2.],
[1., 1., 1.],
[4., 4., 4.],
[3., 3., 3.]])
Batch边列表: EdgeIndex([[1, 2, 0, 3],
[0, 0, 1, 2]], sparse_size=(4, 4), nnz=4, sort_order=col)
Batch邻居采样节点: tensor([1, 0, 3, 2])
观测打印的数据:
Batch邻居采样节点: tensor([1, 0, 3, 2]) 这里是随机的,因为batch_size=2,会随机选取其中两个节点作为初始节点,后边开始两层的挑选,每一层都会挑选两个相邻节点,第二层会根据第一层挑选的相邻节点进行再一次的相邻节点的挑选。