DGL学习笔记03-消息传递机制

DGL学习笔记03-消息传递机制

1 什么是消息传递?

什么是消息传递机制?首先来看下官方的解释(也可以去看论文)
在这里插入图片描述

对于这一节的话,我感觉如果没接触过Message Passing的人可能看了官方文档也不太容易理解它是什么东西,其实它的核心思想就是每个节点给它相邻的节点发消息。
如下图所示,我们先看下红色结点,它将自己的特征发送给相邻的两个蓝色结点,这个发送其实就是对应着官方文档的消息函数。而对于蓝色结点,它接收到红色结点的消息后进行聚合,比如可以进行sum, max, min, mean 等聚合操作。这些聚合操作就是对应官方文档的聚合函数。而蓝色结点得到邻居结点的发过来的消息后就可以更新自己的特征了,比如将所有的消息进行某种方式的聚合后再经过一层ReLu激活函数,结果作为自己的新特征,这就是更新函数的功能了。(需要注意的是,每个节点都具有发送消息、聚合消息、更新消息的功能,不要以为蓝色节点只能聚合消息)
在这里插入图片描述
在这里插入图片描述

举个简单的例子

在这里插入图片描述

我们有上面这样一张简单的图:三个结点,三条有向边,结点的特征是三维的,边的特征也是三维的。先尝试把图创建出来

import torch
import dgl
import dgl.function as fn


graph = dgl.graph(([0, 2, 2],
                   [1, 0, 1]))

node_feats = torch.tensor([[0, 1, 0],
                           [1, 1, 0],
                           [0, 1, 1]], dtype=torch.float)

edge_feats = torch.tensor([[1, 1, 0],
                           [0, 1, 0],
                           [1, 1, 0]
                           ], dtype=torch.float)
graph.ndata['h'] = node_feats

graph.edata['e'] = edge_feats

print(graph)

在这里插入图片描述
接着我们在每个结点上使用消息函数和聚合函数,然后再输出每个结点的特征,看看它们是如何改变的。其中消息函数是将源节点的特征h乘上边的特征e,然后发送到目标结点的mailbox里面,也就是m。聚合函数就是取出mailbox里面的所有消息,然后进行累加,用来更新自己的h

print(graph.ndata['h'])
graph.update_all(fn.u_mul_e('h', 'e', 'm'), fn.sum('m', 'h'))
print(graph.ndata['h'])

在这里插入图片描述
可以看到,原来结点的特征h确实发生了改变,来仔细看看它是怎么改变的。上面那张图再拿下来看看。
在这里插入图片描述
我们看看0节点,只有2号节点给它发消息。2号节点将其特征[0,1,1]乘上边<2, 0>上的特征[0,1,0]得到[0,1,0],然后作为消息发给节点0。因为节点0只有一个节点2给它发消息,所以它的mailbox里只有一条消息[0,1,0],因此使用sum聚合后还是[0,1,0]。同样地,可以算出节点1和节点2更新后的结果为[0,2,0]和[0,0,0]

更多消息函数、聚合函数的使用请参考官方文档dgl.function

  • 5
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 6
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值