第2章:消息传递范式

第2章:消息传递范式

消息传递是实现GNN的一种通用框架和编程范式。它从聚合与更新的角度归纳总结了多种GNN模型的实现。

消息传递范式

假设节点 𝑣 𝑣 v 上的的特征为 𝑥 𝑣 ∈ R 𝑑 1 𝑥_𝑣 \in ℝ^{𝑑1} xvRd1,边 ( 𝑢 , 𝑣 ) (𝑢,𝑣) (u,v) 上的特征为 𝑤 𝑒 ∈ R 𝑑 2 𝑤_𝑒\in ℝ^{𝑑2} weRd2。 消息传递范式 定义了以下逐节点和边上的计算:
边上计算:

点上计算:

在上面的等式中, Font metrics not found for font: . 是定义在每条边上的消息函数,它通过将边上特征与其两端节点的特征相结合来生成消息。 聚合函数 Font metrics not found for font: . 会聚合节点接受到的消息。 更新函数 Font metrics not found for font: . 会结合聚合后的消息和节点本身的特征来更新节点的特征。

内置函数和消息传递API

在DGL中,消息函数 接受一个参数 edges,这是一个 EdgeBatch 的实例, 在消息传递时,它被DGL在内部生成以表示一批边。 edgessrcdstdata 共3个成员属性, 分别用于访问源节点、目标节点和边的特征。

聚合函数 接受一个参数 nodes,这是一个 NodeBatch 的实例, 在消息传递时,它被DGL在内部生成以表示一批节点。 nodes 的成员属性 mailbox 可以用来访问节点收到的消息。 一些最常见的聚合操作包括 summaxmin 等。

更新函数 接受一个如上所述的参数 nodes。此函数对 聚合函数 的聚合结果进行操作, 通常在消息传递的最后一步将其与节点的特征相结合,并将输出作为节点的新特征。

DGL在命名空间 dgl.function 中实现了常用的消息函数和聚合函数作为 内置函数。 一般来说,DGL建议 尽可能 使用内置函数,因为它们经过了大量优化,并且可以自动处理维度广播。

如果用户的消息传递函数无法用内置函数实现,则可以实现自己的消息或聚合函数(也称为 用户定义函数 )。

内置消息函数可以是一元函数或二元函数。对于一元函数,DGL支持 copy 函数。对于二元函数, DGL现在支持 addsubmuldivdot 函数。消息的内置函数的命名约定是 u 表示 节点, v 表示 目标 节点,e 表示 。这些函数的参数是字符串,指示相应节点和边的输入和输出特征字段名。 关于内置函数的列表,请参见 DGL Built-in Function。例如,要对源节点的 hu 特征和目标节点的 hv 特征求和, 然后将结果保存在边的 he 特征上,用户可以使用内置函数 dgl.function.u_add_v('hu', 'hv', 'he')。 而以下用户定义消息函数与此内置函数等价。

def message_func(edges):
     return {'he': edges.src['hu'] + edges.dst['hv']}

DGL支持内置的聚合函数 summaxminmean 操作。 聚合函数通常有两个参数,它们的类型都是字符串。一个用于指定 mailbox 中的字段名,一个用于指示目标节点特征的字段名, 例如, dgl.function.sum('m', 'h')等价于如下所示的对接收到消息求和的用户定义函数:

import torch
def reduce_func(nodes):
     return {'h': torch.sum(nodes.mailbox['m'], dim=1)}

关于用户定义函数的进阶用法,参见 User-defined Functions

在DGL中,也可以在不涉及消息传递的情况下,通过 apply_edges() 单独调用逐边计算。 apply_edges() 的参数是一个消息函数。并且在默认情况下,这个接口将更新所有的边。例如:

import dgl.function as fn
graph.apply_edges(fn.u_add_v('el', 'er', 'e'))

对于消息传递, update_all() 是一个高级API。它在单个API调用里合并了消息生成、 消息聚合和节点特征更新,这为从整体上进行系统优化提供了空间。

update_all() 的参数是一个消息函数、一个聚合函数和一个更新函数。 更新函数是一个可选择的参数,用户也可以不使用它,而是在 update_all 执行完后直接对节点特征进行操作。 由于更新函数通常可以用纯张量操作实现,所以DGL不推荐在 update_all 中指定更新函数。例如:

def update_all_example(graph):
    # 在graph.ndata['ft']中存储结果
    graph.update_all(fn.u_mul_e('ft', 'a', 'm'),
                     fn.sum('m', 'ft'))
    # 在update_all外调用更新函数
    final_ft = graph.ndata['ft'] * 2
    return final_ft

此调用通过将源节点特征 ft 与边特征 a 相乘生成消息 m, 然后对所有消息求和来更新节点特征 ft,再将 ft 乘以2得到最终结果 final_ft

调用后,中间消息 m 将被清除。上述函数的数学公式为:

编写高效的消息传递代码

DGL优化了消息传递的内存消耗和计算速度。利用这些优化的一个常见实践是通过基于内置函数的 update_all() 来开发消息传递功能。

除此之外,考虑到某些图边的数量远远大于节点的数量,DGL建议避免不必要的从点到边的内存拷贝。对于某些情况,比如 GATConv,计算必须在边上保存消息, 那么用户就需要调用基于内置函数的 apply_edges()。有时边上的消息可能是高维的,这会非常消耗内存。 DGL建议用户尽量减少边的特征维数。

下面是一个如何通过对节点特征降维来减少消息维度的示例。该做法执行以下操作:拼接 源 源 节点和 目标 目标 目标 节点特征, 然后应用一个线性层,即 𝑊 × ( 𝑢 ∣ ∣ 𝑣 ) 𝑊×(𝑢||𝑣) W×(u∣∣v) 源 源 节点和 目标 目标 目标 节点特征维数较高,而线性层输出维数较低。 一个直截了当的实现方式如下:

import torch
import torch.nn as nn

linear = nn.Parameter(torch.FloatTensor(size=(node_feat_dim * 2, out_dim)))
def concat_message_function(edges):
     return {'cat_feat': torch.cat([edges.src['feat'], edges.dst['feat']], dim=1)}
g.apply_edges(concat_message_function)
g.edata['out'] = g.edata['cat_feat'] @ linear

建议的实现是将线性操作分成两部分,一个应用于 源 节点特征,另一个应用于 目标 节点特征。 在最后一个阶段,在边上将以上两部分线性操作的结果相加,即执行 𝑊 𝑙 × 𝑢 + 𝑊 𝑟 × 𝑣 𝑊_𝑙×𝑢+𝑊_𝑟×𝑣 Wl×u+Wr×v, 因为 𝑊 × ( 𝑢 ∣ ∣ 𝑣 ) = 𝑊 𝑙 × 𝑢 + 𝑊 𝑟 × 𝑣 𝑊×(𝑢||𝑣)=𝑊_𝑙×𝑢+𝑊_𝑟×𝑣 W×(u∣∣v)=Wl×u+Wr×v,其中 𝑊 𝑙 𝑊_𝑙 Wl 𝑊 𝑟 𝑊_𝑟 Wr 分别是矩阵 𝑊 𝑊 W 的左半部分和右半部分:

import dgl.function as fn

linear_src = nn.Parameter(torch.FloatTensor(size=(node_feat_dim, out_dim)))
linear_dst = nn.Parameter(torch.FloatTensor(size=(node_feat_dim, out_dim)))
out_src = g.ndata['feat'] @ linear_src
out_dst = g.ndata['feat'] @ linear_dst
g.srcdata.update({'out_src': out_src})
g.dstdata.update({'out_dst': out_dst})
g.apply_edges(fn.u_add_v('out_src', 'out_dst', 'out'))

以上两个实现在数学上是等价的。后一种方法效率高得多,因为不需要在边上保存feat_srcfeat_dst, 从内存角度来说是高效的。另外,加法可以通过DGL的内置函数 u_add_v 进行优化,从而进一步加快计算速度并节省内存占用。

在图的一部分上进行消息传递

如果用户只想更新图中的部分节点,可以先通过想要囊括的节点编号创建一个子图, 然后在子图上调用 update_all() 方法。例如:

nid = [0, 2, 3, 6, 7, 9]
sg = g.subgraph(nid)
sg.update_all(message_func, reduce_func, apply_node_func)

这是小批量训练中的常见用法。更多详细用法请参考用户指南 第6章:在大图上的随机(批次)训练

在异构图上进行消息传递

异构图(参考用户指南 1.5 异构图 )是包含不同类型的节点和边的图。 不同类型的节点和边常常具有不同类型的属性。这些属性旨在刻画每一种节点和边的特征。在使用图神经网络时,根据其复杂性, 可能需要使用不同维度的表示来对不同类型的节点和边进行建模。

异构图上的消息传递可以分为两个部分:

  • 对每个关系计算和聚合消息。
  • 对每个结点聚合来自不同关系的消息。

在DGL中,对异构图进行消息传递的接口是 multi_update_all()multi_update_all()接受一个字典。这个字典的每一个键值对里,键是一种关系, 值是这种关系对应 update_all() 的参数。 multi_update_all() 还接受一个字符串来表示跨类型整合函数,来指定整合不同关系聚合结果的方式。 这个整合方式可以是 summinmaxmeanstack 中的一个。以下是一个例子:

import dgl.function as fn

for c_etype in G.canonical_etypes:
    srctype, etype, dsttype = c_etype
    Wh = self.weight[etype](feat_dict[srctype])
    # 把它存在图中用来做消息传递
    G.nodes[srctype].data['Wh_%s' % etype] = Wh
    # 指定每个关系的消息传递函数:(message_func, reduce_func).
    # 注意结果保存在同一个目标特征“h”,说明聚合是逐类进行的。
    funcs[etype] = (fn.copy_u('Wh_%s' % etype, 'm'), fn.mean('m', 'h'))
# 将每个类型消息聚合的结果相加。
G.multi_update_all(funcs, 'sum')
# 返回更新过的节点特征字典
return {ntype : G.nodes[ntype].data['h'] for ntype in G.ntypes}
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

发呆的比目鱼

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值