建议用最新版本的torch_geometric
,不同版本的API变动会比较大。
这个包最关键的一个类是MessagePassing
:
aggr
: 信息传递的方式,默认是add。也就是neighbor的信息聚合是加在center node上的(详见GCN原文)。flow
: 信息传递的方向。这个要和后面的edge_index联合起来理解。默认为source_to_target
。
其他不做解释。
MessagePassing
的forward参数是:
最重要的是这个edge_index
,结合前面的flow
参数,edge_index
包含了你输入这个图的所有边的信息(start node、end node)。如图(黄色highlight的部分),输入的edge_index
一般情况下是LongTensor
,此时形状必须为[2, num_messages]
,第一维存放start node idx, 第二维存放end node idx。比方说:
## 假设图里面有三个节点,node index为 0,1,2
### 有向图:0->1, 1->2, 2->0 , 信息只能在有向边上传递(i.e., 只能0传给1, 1不能传给0)
edge_index = [[0,1,2]
[1,2,0]]
### 无向图,全连接,信息双向传递
edge_index = [[0,1,2,1,2,0]
[1,2,0,0,1,2]]
torch_geometric
目前实现了很多近几年的GCN变体 (GAT,RGCN,etc.), 都是继承自MessagePassing
, 只要理解了这个MessagePassing
和他的edge_index
,这些变体都可以直接调包用就可以了。
目前深度学习里面这种GCN都是空域GCN,所以默认是支持有向图的,可以简单地就把所谓的图神经网络理解为message passing就可以 (频域的话默认只能是无向图,不然理论上GCN无法成立,这一点了解就行,炼丹人不用care)