Google这个标题,我发现居然没人写,那我来浅浅的补充一下。
先放两段代码:
import networkx as nx
import dgl
import dgl.function as fn
import torch
import matplotlib.pyplot as plt
g = dgl.graph(([0, 0, 1, 2, 2, 3], [1, 2, 3, 3, 4, 4]))
g.ndata['x'] = torch.ones(5, 2)
print(g)
print(g.adj())
print(g.adj().to_dense())
nx.draw_networkx(dgl.to_networkx(g, node_attrs=['x']))
plt.show()
g.pull([0, 3, 4], fn.copy_u('x', 'm'), fn.sum('m', 'h'))
print(g.ndata)
g.pull([0, 1, 2, 3, 4], fn.copy_u('x', 'm'), fn.sum('m', 'h'))
print(g.ndata)
import networkx as nx
import dgl
import dgl.function as fn
import torch
import matplotlib.pyplot as plt
g = dgl.graph(([0, 0, 1, 2, 2, 3], [1, 2, 3, 3, 4, 4]))
g.ndata['x'] = torch.ones(5, 2)
print(g)
print(g.adj())
print(g.adj().to_dense())
nx.draw_networkx(dgl.to_networkx(g, node_attrs=['x']))
plt.show()
def message_func(edges):
return {'x': edges.src['x']}
def reduce_func(nodes):
return {'y': nodes.mailbox['x'].sum(1)}
g.update_all(message_func=message_func, reduce_func=reduce_func)
print(g.ndata)
首先要知道的是,dgl的图都是有向图,有向图的边就包括源节点和目标节点,即src指向dst。在实现消息传递机制时,一般会用第二种,update_all,其参数有两个,message_func可以通过图的edges获取到源节点的.src属性,得到要传递的消息,reduce_func在通过.mailbox收到消息后,在对应的目标节点进行计算并更新相应的属性。