torch_geometric使用指南 (作个人纪录)

建议用最新版本的torch_geometric,不同版本的API变动会比较大。

这个包最关键的一个类是MessagePassing
在这里插入图片描述

  • aggr: 信息传递的方式,默认是add。也就是neighbor的信息聚合是加在center node上的(详见GCN原文)。
  • flow: 信息传递的方向。这个要和后面的edge_index联合起来理解。默认为source_to_target

其他不做解释。

MessagePassing的forward参数是:
在这里插入图片描述
最重要的是这个edge_index,结合前面的flow参数,edge_index包含了你输入这个图的所有边的信息(start node、end node)。如图(黄色highlight的部分),输入的edge_index一般情况下是LongTensor,此时形状必须为[2, num_messages],第一维存放start node idx, 第二维存放end node idx。比方说:

## 假设图里面有三个节点,node index为 0,1,2
### 有向图:0->1, 1->2, 2->0 , 信息只能在有向边上传递(i.e., 只能0传给1, 1不能传给0)
edge_index = [[0,1,2]
			  [1,2,0]]
### 无向图,全连接,信息双向传递
edge_index = [[0,1,2,1,2,0]
			  [1,2,0,0,1,2]]

torch_geometric目前实现了很多近几年的GCN变体 (GAT,RGCN,etc.), 都是继承自MessagePassing, 只要理解了这个MessagePassing和他的edge_index,这些变体都可以直接调包用就可以了。

目前深度学习里面这种GCN都是空域GCN,所以默认是支持有向图的,可以简单地就把所谓的图神经网络理解为message passing就可以 (频域的话默认只能是无向图,不然理论上GCN无法成立,这一点了解就行,炼丹人不用care)

参考:
torch_geometric: MessagePassing

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值