关于DGL中update_all函数的理解

graph.update_all(message_func, reduce_func)

写在前面

实际上该函数有三个参数(消息函数,聚合函数,更新函数)。不过按照官方文档的介绍,更新函数完全可以使用张量计算,故而不建议使用更新函数。

一、message_func函数

下面有两个函数,他们完全等价。

如下函数的“feat”字段我们自定义的节点特征名,“m”字段可以自定义,它用来临时存放消息数据。

message_func=fn.copy_src(src="feat",out="m")

def message_func(edges):#将每条边起点节点的特征存入临时区

    return {'m': edges.src['feat']}

 这一步表示把数据存放到临时储存区,用“m”表示临时存放数据的名字

二、reduce_func函数

下面有两个函数,他们完全等价。

reduce_func=fn.sum(msg="m",out="feat")

def reduce_func(nodes):
    feat = torch.sum(nodes.mailbox['m'], dim=1)
    return {"feat": feat}

update_all函数会将message_func的节点特征处理, 方便reduce_func接收。

reduce_func接收到的是每个节点的父亲节点的特征信息,上面函数就求和每个节点特征并将数据返回到原来的“feat”上。

这里有个坑,先看下面代码

#首先构建一个图
import torch
import dgl
import networkx as nx
row = torch.tensor([0,0,0,0,0,1,2,3,4,5,1,2])
col = torch.tensor([1,2,3,4,5,0,0,0,0,0,2,1])
G_test = dgl.graph((row,col))
G_test.ndata["feat"] = torch.randn(6,10)
nx.draw(G_test.to_networkx(), with_labels=True)

def message_func(edges):
    return {'m': edges.src['feat']}

def reduce_func(nodes):
    print("m:",nodes.mailbox['m'].shape)
    feat = torch.sum(nodes.mailbox['m'], dim=1)
    return {"feat": feat}

print("feat:",G_test.ndata["feat"].shape)
G_test.update_all(message_func, reduce_func)

输出如下

feat: torch.Size([6, 10])

m: torch.Size([3, 1, 10])

m: torch.Size([2, 2, 10])

m: torch.Size([1, 5, 10])

 这里遍历每个节点获取他们父亲节点的特征,但tensor维度由二维变成了三维。

原因是DGL加入了并行计算,观察上图节点3,4,5,不难发现他们父亲节点相同,于是这三个节点被合并,纳入了并行计算。

对应输出m: torch.Size([3, 1, 10]),[批量大小, 父亲节点数, 特征维度]。

现在可以思考一下m: torch.Size([2, 2, 10]),m: torch.Size([1, 5, 10])几个数子的含义

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值