DGL消息传递

官网解释

消息传递是实现GNN的一种通用框架和编程范式。它从聚合与更新的角度归纳总结了多种GNN模型的实现。
在这里插入图片描述

示例演示

在这里插入图片描述

图中:

  • 五个节点,索引是[0,1,2,3,4],接着为每个节点添加"n_heat"特征,[10,11,12,13,14]
  • 六条边,并且为边添加"e_feat"特征。
  • 根据上图使用DGL进行建图。

建图

g = dgl.graph(([0,1,2,3,4,3], [1,2,3,4,1,1]))
# 六条边,参数中第一个数组的是边的起点,第二个数组是边的终点。
g.edata['e_feat'] = torch.tensor([1000,2000,3000,4000,5000,6000])
# 根据上面添加的边的顺序,为这些边添加e_feat特征。六条边对应六个值。
g.ndata['n_feat'] = torch.tensor([10,11,12,13,14])
# 为五个节点按照索引依次添加n_feat特征

消息传递流程

  1. 信息函数
    消息函数接受一个参数 edges,这是一个 EdgeBatch 的实例, 在消息传递时,它被DGL在内部生成以表示一批边。 edges 有 src、 dst 和 data 共3个成员属性, 分别用于访问源节点、目标节点和边的特征。每条边的信息只能传递给边终点的节点。

演示:

每条边传递的消息:信息=该边起点的 n _ f e a t n\_feat n_feat+该边的 e _ f e a t e\_feat e_feat
m e s s a g e = n _ f e a t s t a r t + e _ f e a t message =n\_feat_{start}+e\_feat message=n_featstart+e_feat

print('origin edge feat')
def message_func(edges):
    #print(edges.data),张量,表示边的特征,shape=edge_num*e_feat_dim
    #print(edges.src),张量,表示每条边起点的特征,shape=edge_num*n_feat_dim
    #print(edges.dst),张量,表示每条边终点的特征,shape=edge_num*n_feat_dim
    #主要使用上面三类特征
    tmp = {'m':edges.data['e_feat']+edges.src['n_feat']}
    # 上述代码就是将每条边的e_feat和每条边的起点的特征n_feat相加,返回一个字典
    # 如果希望新的e_feat覆盖之前的e_feat,则key为'e_feat'。
    # 否则可以自命名,这里我以'm'为key(这里和字典的更新是一致的)
    return tmp
g.apply_edges(message_func)

运行结果:

  • 经过消息传递后,在每条边上存储了该边所蕴含的信息’m’。(这里是通过apply_edges将每条边的信息m更新到边上的。)以第一条边(0,1)为例,message = 节点0的n_feat + 该边的e_feat,即10+1000=1010。

  • 参数edges的内容
    在这里插入图片描述

  1. 聚合函数
    聚合函数 接受一个参数 nodes,这是一个 NodeBatch 的实例, 在消息传递时,它被DGL在内部生成以表示一批节点。 nodes 的成员属性 mailbox 可以用来访问节点收到的消息。 一些最常见的聚合操作包括 sum、max、min 等。
def reducer(nodes):
    print(nodes.nodes())    # nodes.nodes中包含当前batch中的有哪些节点
    print(nodes.mailbox)    # nodes.mailbox中包含有每个节点接受的消息
    tmp = {'h': torch.sum(nodes.mailbox['m'],dim=1)}
    # 对节点收到的信息求和
    return tmp
g.update_all(message_func, reducer)
#updata_all接受一个信息函数、聚合函数,然后更新每条边终点节点的信息。

运行结果:

  • 初始节点特征只有n_feat,当我们做完消息生成和聚合后,得到新的特征h,并且h特征更新到了节点的特征字典中。以节点1为例,节点1的信息来自三条边(0,1),(4,1),(3,1),三条边传递的信息分别是1010,5014,6013,将信息求和得到12037。
    在这里插入图片描述
    如果文章有错误,请大家指正,欢迎在讨论区留下您的想法。

完整代码

import dgl
import torch
import dgl.function as fn
import networkx as nx
import matplotlib.pyplot as plt

g = dgl.graph(([0,1,2,3,4,3], [1,2,3,4,1,1]))
g.edata['e_feat'] = torch.tensor([1000,2000,3000,4000,5000,6000])
g.ndata['n_feat'] = torch.tensor([10,11,12,13,14])
print('origin node feat')
print(g.ndata)
def message_func(edges):
    # print('-------')
    # print('edges.data:',edges.data)
    # print('edges.src:',edges.src)
    # print('edges.dst:',edges.dst)
    # print('--------')
    tmp = {'m':edges.data['e_feat']+edges.src['n_feat']}
    return tmp
def reducer(nodes):
    print('batch nodes: ',nodes.nodes())
    print('recive message: ',nodes.mailbox)
    print('------')
    tmp = {'h': torch.sum(nodes.mailbox['m'],dim=1)}
    return tmp
g.update_all(message_func, reducer)
#g.apply_edges(message_func)
print('updata node')
#print(g.ndata['h'])
print(g.ndata)
  • 1
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值