graph.update_all(message_func, reduce_func)
写在前面
实际上该函数有三个参数(消息函数,聚合函数,更新函数)。不过按照官方文档的介绍,更新函数完全可以使用张量计算,故而不建议使用更新函数。
一、message_func函数
下面有两个函数,他们完全等价。
如下函数的“feat”字段我们自定义的节点特征名,“m”字段可以自定义,它用来临时存放消息数据。
message_func=fn.copy_src(src="feat",out="m")
def message_func(edges):#将每条边起点节点的特征存入临时区
return {'m': edges.src['feat']}
这一步表示把数据存放到临时储存区,用“m”表示临时存放数据的名字
二、reduce_func函数
下面有两个函数,他们完全等价。
reduce_func=fn.sum(msg="m",out="feat")
def reduce_func(nodes):
feat = torch.sum(nodes.mailbox['m'], dim=1)
return {"feat": feat}
update_all函数会将message_func的节点特征处理, 方便reduce_func接收。
reduce_func接收到的是每个节点的父亲节点的特征信息,上面函数就求和每个节点特征并将数据返回到原来的“feat”上。
这里有个坑,先看下面代码
#首先构建一个图
import torch
import dgl
import networkx as nx
row = torch.tensor([0,0,0,0,0,1,2,3,4,5,1,2])
col = torch.tensor([1,2,3,4,5,0,0,0,0,0,2,1])
G_test = dgl.graph((row,col))
G_test.ndata["feat"] = torch.randn(6,10)
nx.draw(G_test.to_networkx(), with_labels=True)
def message_func(edges):
return {'m': edges.src['feat']}
def reduce_func(nodes):
print("m:",nodes.mailbox['m'].shape)
feat = torch.sum(nodes.mailbox['m'], dim=1)
return {"feat": feat}
print("feat:",G_test.ndata["feat"].shape)
G_test.update_all(message_func, reduce_func)
输出如下
feat: torch.Size([6, 10])
m: torch.Size([3, 1, 10])
m: torch.Size([2, 2, 10])
m: torch.Size([1, 5, 10])
这里遍历每个节点获取他们父亲节点的特征,但tensor维度由二维变成了三维。
原因是DGL加入了并行计算,观察上图节点3,4,5,不难发现他们父亲节点相同,于是这三个节点被合并,纳入了并行计算。
对应输出m: torch.Size([3, 1, 10]),[批量大小, 父亲节点数, 特征维度]。
现在可以思考一下m: torch.Size([2, 2, 10]),m: torch.Size([1, 5, 10])几个数子的含义