DGL 中的update_all函数 的详细理解

该博客详细介绍了如何使用DGL库进行图神经网络操作,包括节点特征向量的相加和乘法运算,以及通过`update_all`函数进行消息传递和聚合。在给定的图中,对每个节点的入度进行特征运算,最终得到每个节点的新特征向量。博客还解释了`update_all`函数中消息函数和聚合函数的工作原理,强调了在图神经网络中节点入度的重要性。
摘要由CSDN通过智能技术生成

所有讨论均以以下代码为基础进行:

 

import dgl
import torch as th

from dglfn  import dgl_fn
import dgl.function as fn 

# https://blog.csdn.net/CY19980216/article/details/110629996?

device=th.device("cuda" if th.cuda.is_available() else 'cpu')

u= th.tensor([0,1,2,3,4,3,4])
v= th.tensor([2,0,1,5,3,2,1])
g=dgl.graph((u,v)).to(device)
#print(g.edges())

g.ndata["value"]=th.ones(g.num_nodes(),3).to(device)  #每个值 都是3维向量
#print(g.ndata["value"])
#g.ndata["ft"]=th.tensor([0,1,2,3,4,5],dtype=th.float32).to(device)
g.edata["weights"]=th.ones(g.num_edges(),3).to(device)
dgl_fn.StartGraph(size=(12,10),title="DGL网络")


dgl_fn.AddGraph(221,g,"value",'weights')
g.apply_edges(fn.u_add_v('value','value','weights'))  #计算边的两个节点的相关运算,并存于边上。

#  计算方法: 源节点+目标节点值 , 例如: 0,1 的wegiths 为1, 2,得出边1 的值 为 3

#g.update_all(fn.u_add_v("value","value","m"),fn.sum("m","weights"))

dgl_fn.AddGraph(222,g,"value",'weights')

#print("the weights of edgs",g.edata["weights"])  #边的权重变为每个元素为2的矩阵

def updata_all_example(graph):
	# store the result in graph.ndata['ft']
	graph.update_all(fn.u_mul_e('value', 'weights', 'm'), fn.sum('m', 'ft'))    # torch.sum(x, 1)
	
	# Call update function outside of update_all
	final_ft = graph.ndata['ft'] 
	return final_ft

all=updata_all_example(g)
#  从源节点到它指向的边,相乘后的值 存放于m中,再对m和ft中的值 进行相加
#
#
#print(g.edges())
print(all)

dgl_fn.AddGraph(223,g,"value",'weights')
dgl_fn.ShowGraph()

在这张图中,有6个节点,总共有7条边。

如图所示:

 

 

为简化起见,预设原始的每节点的value特征是 [1.1,1], 原始的 每条边的权重 wegiths为[1,1,1].

在执行

g.apply_edges(fn.u_add_v('value','value','weights'))   #这个函数 没啥可说的,就是直接向量相加,存到边上

后,所有边分别的权重,即每条边的特征  weights 为  [2,2,2], 没错,是个向量。

记住上面的值 ,即边的weights 权重为 [2,2,2], 每个节点的value为[1,1,1]

最复杂的是下面这句

def updata_all_example(graph):

    # store the result in graph.ndata['ft']

    graph.update_all(fn.u_mul_e('value', 'weights', 'm'), fn.sum('m', 'ft'))    # torch.sum(x, 1)

    

    # Call update function outside of update_all

    final_ft = graph.ndata['ft'] 

    return final_ft

 

调用以上函数:

all=updata_all_example(g)

上面的操作是对每条边进行如下操作:

取源节点的value特征乘以与它相连的边的权重weghts, 这里的乘是向量的按位乘,即相同位置的分量相乘后作为同一个位置的结果,相乘后的值 存放于m中,再对m中的值 进行求和存放于目标节点的mailbox的ft特征中。

print(all)

怎么得到的上面的值 呢?

我们先看最前面的图。 先说一个结论,我们得到的值all是跟节点数一致的一个矩阵。行数即节点数量(这里是6个节点),第一行表示第一个节点0,其它类推。列数是特征维度,即3.

我们有5个节点: 0,1,2,3,4,5

需要详细考察每个节点的入度(是的,dgl是有向图,只算入度。update_all 只影响有入度的顶点,下面标的出度只是为了形象起见,并不影响值 )

0: 1 入  (1出)

1: 2 入 (1出)

2: 2入 (1出)

3:1 入(2出)

4:0入(2出)

5: 1入(0出)

再对照 上图中的值 ,有没有想明白是啥意思?

对: 只计算入度。 以节点1为例 ,每个入度得到的值是  [1,1,1] * [2,2,2] ->  [2,2,2], 总共有两个入度,作sum后,即是 [4,4,4], 对应上面结果中的第二行。

看明白没?

对,出度不会影响这个计算值 。所以,对于节点4, 它的sum后的结果就 [0, 0,0]

 

当然,我们上面的程序中没有更新函数。如果需要,可以直接更新目标节点的相应的特征值 。

 

有一点要注意: mailbox其实是一个抽象的概念,在消息函数运行后,每条边都有一个mailbox以及相关的特征,如上面的 mailbox 中的特征 m,  聚合函数就是把这个点的所有入度的mailbox中的m合并起来,合并的方法函数叫聚合函数。

这点刚开始有点费解。

比如顶点1, 它有两个入度,在执行完消息函数后,其实有两个mailbox ,里面都有一个m 特征, 聚合函数的作用就是把这些多个m 合并起来,然后存在当前节点 1的ft 特征中。

 

 

 

 

  • 4
    点赞
  • 14
    收藏
    觉得还不错? 一键收藏
  • 6
    评论
评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值