g.update_all(message_func=fn.copy_u(‘h‘, ‘m‘), reduce_func=fn.mean(‘m‘, ‘h_N‘))

本文详细介绍了DGL库中的消息函数copy_u和聚合函数mean在图神经网络中如何工作,通过实际例子展示了它们在节点特征传播和更新过程中的作用。特别关注了u_add_v函数的应用,以及在不同图结构下的行为。
摘要由CSDN通过智能技术生成

message_func=fn.copy_u(‘h’, ‘m’)是消息函数;reduce_func=fn.mean(‘m’, ‘h_N’))是聚合函数,聚合节点周围邻居节点的特征;update_all:触发更新。

import dgl
import dgl.function as fn
import torch
g = dgl.graph(([0, 1, 1, 3], [1, 2, 3, 2]))
g.ndata['x'] = torch.ones(4, 2)
g.update_all(fn.copy_u('x', 'm'), fn.mean('m', 'h'))
print(g.ndata['h'])
print(g)
g.update_all(fn.u_add_v('x', 'x', 'm'), fn.sum('m', 'h'))
print(g.ndata['h'])

消息函数:

消息函数接受一个参数 edges,这是一个 dgl.EdgeBatch 的实例, 在消息传递时,它被DGL在内部生成以表示一批边。edges有三个成员属性:src、dst和data,分别用于访问源节点、目标节点和边的特征。
在使用dgl定义的message passing函数时,经常使用如dgl.function.copy_u、dgl.function.copy_src和dgl.function.copy_e函数。这些函数的底层一般是用edges.src/edges.dst/edges.data实现。

copu_u

copy源节点的x特征到m,然后对m取均值得到节点特征h,由于这个图是0->1,1->2, 1->3, 3->2,因此0节点是没有消息传过来的。
copy_u将源节点的特征传递给邻居节点,保存在’m’中

# Builtin message function that computes message 
# using source node feature.

import dgl
message_func = dgl.function.copy_u('h', 'm')

# 这两个函数是等价的,访问源节点特征,并且存到m中
def message_func(edges):
    return {'m': edges.src['h']}  # 访问边对应的源节点特征
   

栗子:

g.update_all(fn.copy_u('x', 'm'), fn.mean('m', 'h'))
print(g.ndata['h'])

tensor([
[0., 0.],
[1., 1.],
[1., 1.],
[1., 1.]])

u_add_v

u_add_v是将源节点和目标节点的x特征相加得到m,然后对所有m求和得到h,同样0节点没有消息传来。

g.update_all(fn.u_add_v('x', 'x', 'm'), fn.sum('m', 'h'))
print(g.ndata['h'])

tensor([
[0., 0.],
[2., 2.],
[4., 4.],
[2., 2.]])

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

开开开心果儿

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值