DGL学习笔记03-消息传递机制
1 什么是消息传递?
什么是消息传递机制?首先来看下官方的解释(也可以去看论文)
对于这一节的话,我感觉如果没接触过Message Passing的人可能看了官方文档也不太容易理解它是什么东西,其实它的核心思想就是每个节点给它相邻的节点发消息。
如下图所示,我们先看下红色结点,它将自己的特征发送给相邻的两个蓝色结点,这个发送其实就是对应着官方文档的消息函数。而对于蓝色结点,它接收到红色结点的消息后进行聚合,比如可以进行sum, max, min, mean 等聚合操作。这些聚合操作就是对应官方文档的聚合函数。而蓝色结点得到邻居结点的发过来的消息后就可以更新自己的特征了,比如将所有的消息进行某种方式的聚合后再经过一层ReLu激活函数,结果作为自己的新特征,这就是更新函数的功能了。(需要注意的是,每个节点都具有发送消息、聚合消息、更新消息的功能,不要以为蓝色节点只能聚合消息)
举个简单的例子
我们有上面这样一张简单的图:三个结点,三条有向边,结点的特征是三维的,边的特征也是三维的。先尝试把图创建出来
import torch
import dgl
import dgl.function as fn
graph = dgl.graph(([0, 2, 2],
[1, 0, 1]))
node_feats = torch.tensor([[0, 1, 0],
[1, 1, 0],
[0, 1, 1]], dtype=torch.float)
edge_feats = torch.tensor([[1, 1, 0],
[0, 1, 0],
[1, 1, 0]
], dtype=torch.float)
graph.ndata['h'] = node_feats
graph.edata['e'] = edge_feats
print(graph)
接着我们在每个结点上使用消息函数和聚合函数,然后再输出每个结点的特征,看看它们是如何改变的。其中消息函数是将源节点的特征h乘上边的特征e,然后发送到目标结点的mailbox里面,也就是m。聚合函数就是取出mailbox里面的所有消息,然后进行累加,用来更新自己的h
print(graph.ndata['h'])
graph.update_all(fn.u_mul_e('h', 'e', 'm'), fn.sum('m', 'h'))
print(graph.ndata['h'])
可以看到,原来结点的特征h确实发生了改变,来仔细看看它是怎么改变的。上面那张图再拿下来看看。
我们看看0节点,只有2号节点给它发消息。2号节点将其特征[0,1,1]乘上边<2, 0>上的特征[0,1,0]得到[0,1,0],然后作为消息发给节点0。因为节点0只有一个节点2给它发消息,所以它的mailbox里只有一条消息[0,1,0],因此使用sum聚合后还是[0,1,0]。同样地,可以算出节点1和节点2更新后的结果为[0,2,0]和[0,0,0]
更多消息函数、聚合函数的使用请参考官方文档dgl.function