def _region_aggregate(self, feats, edge_dict):
"""
feats: node features [n,feature_dim]
edge_dict: adjacency list [[第一个node 邻接的node的编号],[],[]] 也就是node2node
"""
N = feats.size()[0]
pooled_feats = torch.stack([torch.mean(feats[edge_dict[i]], dim=0) for i in range(N)])
return pooled_feats
GNN汇集邻域信息代码
最新推荐文章于 2022-04-30 18:16:32 发布