图神经网络框架DGL学习 103——信息传递 (Message Passing Tutorial)

在图神经网络中,信息的传递和特征的转变,用户可以自定义的。当然在DGL中,也有高级别的API供调用。
现在来看一个网页排名简单的模型。每一个节点都有相同的PV值,PV=0.01, 每一个节点首先会均匀分散自己的PV值给周围的节点。各节点新的PV值等于周围节点的聚合,同时受到阻尼因子的调节。因此,每一次迭代,节点PV值的变化如下:
在这里插入图片描述
其中,d为阻尼因子,N为节点数。N为节点的邻接节点,D为节点的输出度(deg)

目的:查看10次迭代之后,各节点的PV值

首先,导入需要的模块

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import dgl
import torch
import dgl.function as fn

一、构建图

N = 100 #节点数量
DAMP=0.85 #公式中的d
K = 10 #迭代次数
g = nx.nx.erdos_renyi_graph(N, 0.1) #随机erdos_renyi图
g = dgl.DGLGraph(g)
#可视化
nx.draw(g.to_networkx(), with_labels=True, nodes_size=50)
plt.show()

在这里插入图片描述
初始化节点PV值和输出度:

g.ndata['PV'] = torch.ones(N)/N #初始数据每个节点均为0.01
g.ndata['deg'] = g.out_degrees(g.nodes()).float() #每个节点输出度
print(g.ndata)

在这里插入图片描述

二、信息传递、聚合的四种方法

(1)完全用户自定义模式

分为四步:

1. 节点的输出PV值函数

def pagerank_message_func(edges):
    '''
    信息输出函数
    :param edges: 图g的边对象,多条边(边对象具有src, dst, data三个属性, 分别代表边的起始节点,终止节点,边特征)
                这里只对输出节点进行迭代PV值
    :return:
    '''
    return {'PV': edges.src['PV'] / edges.src['deg']}

2. 节点PV信息聚合函数

def pagerank_reduce_func(nodes):
    '''
    信息的聚合(衰减)函数
    :param nodes: 输入的节点对象
    :return:
    '''
    megs = torch.sum(nodes.mailbox['PV'], dim=1)
    pv = (1-DAMP)/N + DAMP*megs
    return {'PV': pv}

3. 节点输出函数、信息聚合函数导入图

#对信息传递函数和信息聚合(衰减)函数进行注册在g网络中。
g.register_message_func(pagerank_message_func)
g.register_reduce_func(pagerank_reduce_func)

4. 向前传播

#信息正向传播
def pagerank_naive(g):
    #阶段一:信息沿着节点发出所有信息
    for u, v in zip(*g.edges()):
        g.send((u,v)) #输入的是边
    #阶段二:接收信息,计算新的PV值
    for v in g.nodes():
        g.recv(v) #输入的是终止节点

(2) 适合大图的批量处理方法

与(1)中的方法类似,也需要先定义好信息传播函数、信息聚合函数,并将这两个函数传入图g中。有差别的是第四步。将第四步替换成以下内容:

ef pagerank_batch(g):
    g.send(g.edges())
    g.recv(g.nodes())

批量出的原理:
“您可能想知道是否有可能在所有节点上并行执行reduce,因为每个节点可能有不同数量的传入消息,
并且您无法真正将不同长度的张量真正“堆叠”在一起。 通常,DGL通过按传入消息的数量对节点进行分组并为每个组调用reduce函数来解决该问题。”

(3)使用dgl中level 2的API

与(1)类似,前三步,万全相同,只是修改第四步,如下:

#DGL中可以使用更高级的API更新图 (level-2 APIs)
def pagerank_level2(g):
    g.update_all()

(4) 从头到尾直接使用更高效的内置函数

#还有一种更有效的方法,使用dgl的内置函数。该方法执行起来更快。
def pagerank_builtin(g):
    '''
        1. fn.copy_src: 构建信息输出函数,将起始节点的信息传播出去
        2. fn.sum:信息聚集
    :param g:  图对象
    :return:
    '''
    g.ndata['PV'] = g.ndata['PV']/g.ndata['deg']
    g.update_all(message_func=fn.copy_src(src='PV', out='m'),
                 reduce_func=fn.sum(msg='m', out='m_sum'))
    g.ndata['PV'] = (1-DAMP)/N + DAMP * g.ndata['m_sum']

三、进行N次迭代

#进行K次迭代
for k in range(K):
    pagerank_builtin(g)
print(g.ndata['PV'])

在这里插入图片描述

  • 2
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值