图神经网络DGL库之消息传递

1 消息传递

1.1 图解

在这里插入图片描述
对上图的绿色框的函数进行解释:

  • 消息函数(message function):消息函数可以接收源节点的e.src.data,边的e.data以及目标节点的e.dst.data,之后将三者数据进行一些操作(例如加和),最终将数据存放在Mailbox
  • apply_nodes函数:可以使用目标节点的e.dst.data数据进行一些操作(例如e.dst.data+1)
  • 聚合函数(reduce function):可以获取目标节点以及Mailbox数据。将Mailbox数据提取出来,并清空Mailbox,之后更新目标节点。

对上图未提及的函数进行说明:

  • apply_edges函数:可以将消息函数操作后的数据附加在边上
  • update_all函数(更新):启用消息函数和聚合函数,即开始更新节点的流程(消息传递+消息聚合)。

1.2 语法格式

1.2.1 message函数

message函数采用单个参数edges(具有三个成员src,dst和data)分别用于访问源节点,目标节点和边的特征,如下:

def message_func(edges):

1.2.2 reduce函数

reduce函数采用单个参数节点nodes。 节点的成员属性mailbox可以用来访问节点收到的信息,然后做一些运算

  • 一些最常见的聚合运算包括sum,max,min等

如下:

def reducer(nodes):

1.2.3 update函数

调用节点计算的接口是update_all(),它在单个API调用里合并了消息生成、消息聚合和节点特征更新。update_all的参数是消息函数,reduce函数和更新函数

  • 更新函数是可选择参数,用户可以不使用,DGL不推荐在 update_all 中指定更新函数
  • 该函数相当于开始更新节点的流程(消息传递+消息聚合+节点特征更新)。

1.2.4 apply_nodes函数

语法格式:

DGLGraph.apply_nodes(func, v='__ALL__', ntype=None, inplace=False)

参数解释如下:

  • func:用于更新节点特征的函数。
  • v:默认是更新所有节点。
  • ntype:可选,节点类型名称。如果图中只有一个类型的节点,则可以省略。
  • 最后一个已弃用

1.2.5 apply_edges函数

DGLGraph.apply_edges(func, edges='__ALL__', etype=None, inplace=False)

参数解释如下:

  • func:用于生成新的边特征。
  • v:默认是更新所有边。
  • ntype:可选,边类型名称。如果图中只有一个类型的边,则可以省略。
  • 最后一个已弃用

2 具体例子

2.1 建图

示例图如下:
在这里插入图片描述
建图代码如下:

import dgl
import torch

g = dgl.graph(([0,1,2], [1,2,0]))
g.edata['e_feat'] = torch.tensor([2000,3000,4000])
g.ndata['n_feat'] = torch.tensor([20,21,22])

2.2 消息传递

2.2.1 函数构造

该消息传递方式将源节点的特征和边的特征进行聚合

def message_func(edges):
    # 常规属性操作如下:
    # print('edges.data:',edges.data)
    # print('edges.src:',edges.src)
    # print('edges.dst:',edges.dst)
    tmp = {'m':edges.data['e_feat']+edges.src['n_feat']}
    return tmp

2.2.2 边更新

将消息传递函数应用在边上,更新边的特征,代码如下:

import dgl
import torch

g = dgl.graph(([0,1,2], [1,2,0]))
g.edata['e_feat'] = torch.tensor([2000,3000,4000])
g.ndata['n_feat'] = torch.tensor([20,21,22])
print('origin edge feat')
print(g.edata)
print('-------------------------------')
def message_func(edges):
    # 常规属性操作如下:
    # print('edges.data:',edges.data)
    # print('edges.src:',edges.src)
    # print('edges.dst:',edges.dst)
    tmp = {'m':edges.data['e_feat']+edges.src['n_feat']}
    return tmp

g.apply_edges(message_func)
print('updata edge')
print(g.edata)

运行时,以(0,1)边为例,m=节点0的n_feat + 该边的e_feat,即20+2000=2020,以此类推,结果如下:
在这里插入图片描述

2.2.3 节点更新

import dgl
import torch

# 创建图
g = dgl.graph(([0, 1, 2], [1, 2, 0]))
g.edata['e_feat'] = torch.tensor([2000, 3000, 4000])  # 边特征
g.ndata['n_feat'] = torch.tensor([20, 21, 22])  # 节点特征
print('Original node features:')
print(g.ndata)
print('-------------------------------')

g.apply_nodes(lambda nodes: {'n_feat': nodes.data['n_feat'] * 2})

# 打印更新后的节点特征
print('Updated node features:')
print(g.ndata)

将节点信息×2,结果如图所示:
在这里插入图片描述

2.2.4 消息聚合

1 未使用更新函数
import dgl
import torch

g = dgl.graph(([0,1,2], [1,2,0]))
g.edata['e_feat'] = torch.tensor([2000,3000,4000])
g.ndata['n_feat'] = torch.tensor([20,21,22])
print('origin node feat')
print(g.ndata)
print('-------------------------------')
def message_func(edges):
    # 常规属性操作如下:
    # print('edges.data:',edges.data)
    # print('edges.src:',edges.src)
    # print('edges.dst:',edges.dst)
    tmp = {'m':edges.data['e_feat']+edges.src['n_feat']}
    return tmp
def reducer(nodes):
    # DGL中,批次中的节点是按照图的划分和计算需求确定的
    print('batch nodes: ',nodes.nodes())
    # nodes.mailbox 只包含在 message_func 中生成并发送到节点的消息
    print('mailbox: ',nodes.mailbox)
    print('--------------------------')
    # 每一行进行求和,目的是将数据转成列表格式
    tmp = {'h': torch.sum(nodes.mailbox['m'],dim=1)}
    return tmp
g.update_all(message_func, reducer)
print('updata node')
print(g.ndata)
print('edge')
print(g.edata)

经过了消息生成、消息聚合和节点特征更新过程,将新特征h更新到节点的特征字典中。

  • 注意:这个过程并不会把特征m加到边的特征字典中

在这里插入图片描述

2 使用更新函数

很少这么用,不建议

import dgl
import torch

# 创建图
g = dgl.graph(([0, 1, 2], [1, 2, 0]))
g.edata['e_feat'] = torch.tensor([2000, 3000, 4000])  # 边特征
g.ndata['n_feat'] = torch.tensor([20, 21, 22])  # 节点特征
print('Original node features:')
print(g.ndata)
print('-------------------------------')

# 消息传递函数
def message_func(edges):
    # 计算消息:边特征 + 源节点特征
    return {'m': edges.data['e_feat'] + edges.src['n_feat']}

# 聚合函数
def reducer(nodes):
    # 打印批次节点和邮件箱内容
    print('Batch nodes: ', nodes.nodes())
    print('Mailbox: ', nodes.mailbox)
    print('--------------------------')
    # 对消息进行求和
    return {'h': torch.sum(nodes.mailbox['m'], dim=1)}

# 更新节点特征的函数
def update_node_features(nodes):
    # 使用聚合后的特征更新节点特征
    # nodes.data['h'] 是聚合后的消息
    # nodes.data['n_feat'] 是节点的原始特征
    updated_feat = nodes.data['n_feat'] + nodes.data['h']
    return {'h': updated_feat}

# 执行消息传递和聚合
g.update_all(message_func, reducer)

# 在消息传递后,使用 apply_nodes 更新节点特征
g.update_all(message_func, reducer,update_node_features) # 获取聚合结果

# 打印更新后的节点特征
print('Updated node features:')
print(g.ndata)
print('Edge features:')
print(g.edata)

结果如图所示:
在这里插入图片描述

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值