dgl的消息传递机制

Google这个标题,我发现居然没人写,那我来浅浅的补充一下。
先放两段代码:

import networkx as nx
import dgl
import dgl.function as fn
import torch
import matplotlib.pyplot as plt

g = dgl.graph(([0, 0, 1, 2, 2, 3], [1, 2, 3, 3, 4, 4]))
g.ndata['x'] = torch.ones(5, 2)
print(g)
print(g.adj())
print(g.adj().to_dense())
nx.draw_networkx(dgl.to_networkx(g, node_attrs=['x']))
plt.show()

g.pull([0, 3, 4], fn.copy_u('x', 'm'), fn.sum('m', 'h'))
print(g.ndata)
g.pull([0, 1, 2, 3, 4], fn.copy_u('x', 'm'), fn.sum('m', 'h'))
print(g.ndata)
import networkx as nx
import dgl
import dgl.function as fn
import torch
import matplotlib.pyplot as plt

g = dgl.graph(([0, 0, 1, 2, 2, 3], [1, 2, 3, 3, 4, 4]))
g.ndata['x'] = torch.ones(5, 2)
print(g)
print(g.adj())
print(g.adj().to_dense())
nx.draw_networkx(dgl.to_networkx(g, node_attrs=['x']))
plt.show()

def message_func(edges):
    return {'x': edges.src['x']}
    
def reduce_func(nodes):
    return {'y': nodes.mailbox['x'].sum(1)}


g.update_all(message_func=message_func, reduce_func=reduce_func)
print(g.ndata)

首先要知道的是,dgl的图都是有向图,有向图的边就包括源节点和目标节点,即src指向dst。在实现消息传递机制时,一般会用第二种,update_all,其参数有两个,message_func可以通过图的edges获取到源节点的.src属性,得到要传递的消息,reduce_func在通过.mailbox收到消息后,在对应的目标节点进行计算并更新相应的属性。

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值